#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
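
//===----------------------------------------------------------------------===//
// The patterns in this file emulate vector operations over narrow types
// (e.g. i4) with operations over wider, hardware-supported types (e.g. i8):
// loads and stores are rewritten against the type-converted memref using
// linearized indices, and bitcast/trunci/ext chains are rewritten into
// vector.shuffle plus bitwise-op sequences.
//===----------------------------------------------------------------------===//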
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getValueToStore().getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = origElements / scale;
    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements, newElementType),
        op.getValueToStore());

    rewriter.replaceOpWithNewOp<vector::StoreOp>(
        op, bitCast.getResult(), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
    return success();
  }
};
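
// A sketch of the rewrite on IR (illustrative only; shapes and SSA names are
// made up, assuming i4 is emulated with i8 so that scale = 2):
//
//   vector.store %v, %src[%i] : memref<64xi4>, vector<8xi4>
//
// becomes, against the converted memref<32xi8>:
//
//   %bc = vector.bitcast %v : vector<8xi4> to vector<4xi8>
//   vector.store %bc, %dst[%linearized] : memref<32xi8>, vector<4xi8>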
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origElements = op.getVectorType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = (origElements + scale - 1) / scale;
    auto newLoad = rewriter.create<vector::LoadOp>(
        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));

    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
    rewriter.replaceOp(op, bitCast->getResult(0));
    return success();
  }
};
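
// Sketch of the load rewrite under the same i4 -> i8 assumption (names and
// shapes illustrative only):
//
//   %0 = vector.load %src[%i] : memref<64xi4>, vector<8xi4>
//
// becomes a wider load followed by a bitcast back to the narrow type:
//
//   %1 = vector.load %dst[%linearized] : memref<32xi8>, vector<4xi8>
//   %0 = vector.bitcast %1 : vector<4xi8> to vector<8xi4>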
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = (origElements + scale - 1) / scale;
    auto newType = VectorType::get(numElements, newElementType);

    // Walk up through vector.extract ops to find the mask-producing op.
    auto maskOp = op.getMask().getDefiningOp();
    SmallVector<vector::ExtractOp, 2> extractOps;
    while (maskOp &&
           !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
      if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
        maskOp = extractOp.getVector().getDefiningOp();
        extractOps.push_back(extractOp);
      }
    }
    auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
    auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Compute the "compressed" mask; only the innermost dimension changes.
    Operation *newMask = nullptr;
    auto shape = llvm::to_vector(
        maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
    shape.push_back(numElements);
    auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();
      // New innermost mask bound: ceilDiv(oldBound, scale).
      AffineExpr s0;
      bindSymbols(rewriter.getContext(), s0);
      s0 = s0 + scale - 1;
      s0 = s0.floorDiv(scale);
      OpFoldResult origIndex =
          getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
      OpFoldResult maskIndex =
          affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
      auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
      newMaskOperands.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
      newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                      newMaskOperands);
    } else if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
      auto numMaskOperands = maskDimSizes.size();
      auto origIndex =
          cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
      auto maskIndex =
          rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
      auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
      newMaskDimSizes.push_back(maskIndex);
      newMask = rewriter.create<vector::ConstantMaskOp>(
          loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
    }

    while (!extractOps.empty()) {
      newMask = rewriter.create<vector::ExtractOp>(
          loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
      extractOps.pop_back();
    }

    auto newPassThru =
        rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());

    // Generate the new masked load on the wider type.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newMask->getResult(0), newPassThru);

    // Bitcast back and select against the original mask so that lanes that
    // were not actually loaded keep the original pass-through values.
    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
    auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
                                                   op.getPassThru());
    rewriter.replaceOp(op, select->getResult(0));
    return success();
  }
};
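
// Sketch (illustrative, i4 -> i8, scale = 2): a masked load of 8 x i4 whose
// mask is `vector.create_mask %c6` is emulated by compressing the innermost
// mask bound to ceilDiv(6, 2) = 3, issuing a masked load of 4 x i8 with the
// compressed mask and a bitcast pass-through, bitcasting the result back to
// 8 x i4, and finally arith.select-ing against the original mask so that
// masked-off lanes keep their original pass-through values.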
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origElements = op.getVectorType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());

    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = (origElements + scale - 1) / scale;
    auto newReadType = VectorType::get(numElements, newElementType);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, newReadType, adaptor.getSource(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
    rewriter.replaceOp(op, bitCast->getResult(0));
    return success();
  }
};
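
// Sketch (illustrative, i4 -> i8): a transfer_read of vector<8xi4> with
// padding %p becomes a transfer_read of vector<4xi8> whose padding is
// `arith.extui %p : i4 to i8`, followed by a vector.bitcast back to
// vector<8xi4>.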
/// For each element of the target vector, record which source element and
/// which bit range within it contribute bits.
struct SourceElementRange {
  /// The index of the source vector element that contributes bits to the
  /// target vector element.
  int64_t sourceElementIdx;
  /// The range of bits in the source vector element that contribute to the
  /// target element.
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// Given the index of a SourceElementRange in the list, compute how far its
  /// bits must be shifted left to land in their final position: the sum of
  /// the widths of all preceding entries.
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};
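
// Worked example (assumed shapes): for vector<8xi4> -> vector<4xi8>, result
// element i is fed by bits [0..4) of source elements 2*i and 2*i+1, i.e. its
// range list is {{2*i, 0, 4}, {2*i+1, 0, 4}}. Then
// computeLeftShiftAmount(0) == 0 and computeLeftShiftAmount(1) == 4: the
// second contribution is shifted past the 4 bits already filled.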
struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};
struct BitCastRewriter {
  /// Static quantities precomputed for one rewrite step.
  struct Metadata {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
  };

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  LogicalResult precondition(PatternRewriter &rewriter,
                             VectorType preconditionVectorType, Operation *op);

  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
                    Value runningResult,
                    const BitCastRewriter::Metadata &metadata);

private:
  /// Underlying enumerator encoding the provenance of each result bit.
  BitCastBitsEnumerator enumerator;
};
[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
    }
    os << "\n";
  }
  return os;
}
BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG("sourceVectorType: " << sourceVectorType);

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG("targetVectorType: " << targetVectorType);

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  // Prepopulate one source element range list per target element, then walk
  // the result bits and record, for each target element, which source bits
  // feed it.
  sourceElementRanges =
      SmallVector<SourceElementRangeList>(mostMinorTargetDim);
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
}
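
// Worked example of the bit walk above (assumed types): for
// vector<4xi8> -> vector<8xi4>, bitwidth = 32 and every step is
// min(8 - b%8, 4 - b%4) = 4, so each target i4 element gets exactly one
// entry: sourceElementRanges[i] == {{i/2, 4*(i%2), 4*(i%2) + 4}}.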
BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG("\n" << enumerator.sourceElementRanges);
}

LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
                                            VectorType precondition,
                                            Operation *op) {
  if (precondition.getRank() != 1 || precondition.isScalable())
    return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");

  // TODO: consider relaxing this restriction in the future.
  int64_t resultBitwidth = precondition.getElementTypeBitWidth();
  if (resultBitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");

  return success();
}
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

    // Create the attribute quantities for the shuffle / mask / shift ops.
    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
                                  : 0;
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
                          : 0;
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
                          : 0;
      IntegerAttr mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
                                  bitLo, bitHi));
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftRight));

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftLeft));
    }

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
  }
  return result;
}
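
// Example metadata (assumed case vector<8xi4> -> vector<4xi8>, shuffled
// element type i8): getMaxNumberOfEntries() == 2, so two steps are produced:
//   step 0: shuffles = [0, 2, 4, 6], mask = 0x0f, shr = 0, shl = 0
//   step 1: shuffles = [1, 3, 5, 7], mask = 0x0f, shr = 0, shl = 4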
Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
                                   Value initialValue, Value runningResult,
                                   const BitCastRewriter::Metadata &metadata) {
  // Create vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Intersect with the mask.
  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);

  // Align right on 0.
  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
  Value shiftedRight =
      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

  // Shift bits left into their final position.
  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
  Value shiftedLeft =
      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

  runningResult =
      runningResult
          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
          : shiftedLeft;

  return runningResult;
}
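
// One step thus expands to roughly (a sketch, continuing step 1 of the
// i4 -> i8 example above; constants shown as splats):
//
//   %s = vector.shuffle %src, %src [1, 3, 5, 7]
//   %a = arith.andi %s, <splat 0x0f>
//   %r = arith.shrui %a, <splat 0>
//   %l = arith.shli %r, <splat 4>
//   %next = arith.ori %running, %l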
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a trunc op.
    auto truncOp =
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
    if (!truncOp)
      return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
                                      runningResult, metadata);
    }

    // Finalize: narrow or extend to the result element type.
    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
    }
    return success();
  }
};
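
// Sketch of the overall rewrite (illustrative types): for
//
//   %1 = arith.trunci %0 : vector<8xi32> to vector<8xi4>
//   %2 = vector.bitcast %1 : vector<8xi4> to vector<4xi8>
//
// the shuffle/mask/shift steps are applied directly to the pre-trunc value
// %0, and since the result element type (i8) is narrower than the shuffled
// element type (i32), the chain is finalized with a single arith.trunci to
// vector<4xi8>.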
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  using OpRewritePattern<ExtOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a bitcast op.
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
    if (!bitCastOp)
      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.precondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
                                      sourceValue, runningResult, metadata);
    }

    // Finalize: narrow or extend to the result element type.
    bool narrowing =
        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<ExtOpType>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    }
    return success();
  }
};
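
// Sketch (illustrative types): for
//
//   %b = vector.bitcast %a : vector<4xi8> to vector<8xi4>
//   %e = arith.extui %b : vector<8xi4> to vector<8xi32>
//
// the shuffle/mask/shift steps run on the i8 source directly, and since the
// result element type (i32) is wider than the shuffled element type (i8),
// the chain is finalized with the original ext op kind to vector<8xi32>.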
void vector::populateVectorNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // Populate `vector.*` load/store conversion patterns.
  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
               ConvertVectorTransferRead>(typeConverter, patterns.getContext());
}

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);
}
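
// A minimal driver sketch showing how the emulation entry point above might
// be wired into a dialect-conversion run. `runNarrowTypeEmulation` is
// hypothetical (not part of this file), and a real pass must also configure
// op legality on the ConversionTarget; only the populate call is taken from
// this file.
static LogicalResult runNarrowTypeEmulation(Operation *root,
                                            MLIRContext *ctx) {
  // Emulate sub-byte element types with 8-bit storage.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);

  ConversionTarget target(*ctx);
  // ... mark vector/memref/arith ops dynamically legal against typeConverter.

  RewritePatternSet patterns(ctx);
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
  return applyPartialConversion(root, target, std::move(patterns));
}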