23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/raw_ostream.h"
30 #define DEBUG_TYPE "vector-narrow-type-emulation"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
32 #define DBGSNL() (llvm::dbgs() << "\n")
33 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
45 int origElements,
int scale) {
46 auto numElements = (origElements + scale - 1) / scale;
51 while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
52 if (
auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
53 maskOp = extractOp.getVector().getDefiningOp();
54 extractOps.push_back(extractOp);
57 auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
58 auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
59 if (!createMaskOp && !constantMaskOp)
67 shape.back() = numElements;
71 size_t numMaskOperands = maskOperands.size();
81 newMaskOperands.push_back(
83 newMask = rewriter.
create<vector::CreateMaskOp>(loc, newMaskType,
85 }
else if (constantMaskOp) {
87 constantMaskOp.getMaskDimSizes().getValue();
88 size_t numMaskOperands = maskDimSizes.size();
90 cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
91 IntegerAttr maskIndexAttr =
94 newMaskDimSizes.push_back(maskIndexAttr);
95 newMask = rewriter.
create<vector::ConstantMaskOp>(
96 loc, newMaskType, rewriter.
getArrayAttr(newMaskDimSizes));
99 while (!extractOps.empty()) {
100 newMask = rewriter.
create<vector::ExtractOp>(
101 loc, newMask->
getResults()[0], extractOps.back().getMixedPosition());
102 extractOps.pop_back();
118 matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
122 auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
123 Type oldElementType = op.getValueToStore().getType().getElementType();
124 Type newElementType = convertedType.getElementType();
128 if (dstBits % srcBits != 0) {
130 op,
"only dstBits % srcBits == 0 supported");
132 int scale = dstBits / srcBits;
147 auto origElements = op.getValueToStore().getType().getNumElements();
148 if (origElements % scale != 0)
151 auto stridedMetadata =
152 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
155 std::tie(std::ignore, linearizedIndices) =
157 rewriter, loc, srcBits, dstBits,
158 stridedMetadata.getConstifiedMixedOffset(),
159 stridedMetadata.getConstifiedMixedSizes(),
160 stridedMetadata.getConstifiedMixedStrides(),
163 auto numElements = origElements / scale;
164 auto bitCast = rewriter.
create<vector::BitCastOp>(
166 op.getValueToStore());
169 op, bitCast.
getResult(), adaptor.getBase(),
179 struct ConvertVectorMaskedStore final
184 matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
188 auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
189 Type oldElementType = op.getValueToStore().getType().getElementType();
190 Type newElementType = convertedType.getElementType();
194 if (dstBits % srcBits != 0) {
196 op,
"only dstBits % srcBits == 0 supported");
199 int scale = dstBits / srcBits;
200 int origElements = op.getValueToStore().getType().getNumElements();
201 if (origElements % scale != 0)
204 auto stridedMetadata =
205 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
207 std::tie(std::ignore, linearizedIndicesOfr) =
209 rewriter, loc, srcBits, dstBits,
210 stridedMetadata.getConstifiedMixedOffset(),
211 stridedMetadata.getConstifiedMixedSizes(),
212 stridedMetadata.getConstifiedMixedStrides(),
214 Value linearizedIndices =
235 FailureOr<Operation *> newMask =
240 auto numElements = (origElements + scale - 1) / scale;
242 auto passThru = rewriter.
create<arith::ConstantOp>(
245 auto newLoad = rewriter.
create<vector::MaskedLoadOp>(
246 loc, newType, adaptor.getBase(), linearizedIndices,
247 newMask.value()->getResult(0), passThru);
249 Value valueToStore = rewriter.
create<vector::BitCastOp>(
250 loc, op.getValueToStore().getType(), newLoad);
251 valueToStore = rewriter.
create<arith::SelectOp>(
252 loc, op.getMask(), op.getValueToStore(), valueToStore);
254 rewriter.
create<vector::BitCastOp>(loc, newType, valueToStore);
257 op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
271 matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
275 auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
276 Type oldElementType = op.getType().getElementType();
277 Type newElementType = convertedType.getElementType();
281 if (dstBits % srcBits != 0) {
283 op,
"only dstBits % srcBits == 0 supported");
285 int scale = dstBits / srcBits;
304 auto origElements = op.getVectorType().getNumElements();
305 if (origElements % scale != 0)
308 auto stridedMetadata =
309 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
312 std::tie(std::ignore, linearizedIndices) =
314 rewriter, loc, srcBits, dstBits,
315 stridedMetadata.getConstifiedMixedOffset(),
316 stridedMetadata.getConstifiedMixedSizes(),
317 stridedMetadata.getConstifiedMixedStrides(),
320 auto numElements = (origElements + scale - 1) / scale;
321 auto newLoad = rewriter.
create<vector::LoadOp>(
326 rewriter.
create<vector::BitCastOp>(loc, op.getType(), newLoad);
328 rewriter.
replaceOp(op, bitCast->getResult(0));
337 struct ConvertVectorMaskedLoad final
342 matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
346 auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
347 Type oldElementType = op.getType().getElementType();
348 Type newElementType = convertedType.getElementType();
352 if (dstBits % srcBits != 0) {
354 op,
"only dstBits % srcBits == 0 supported");
356 int scale = dstBits / srcBits;
400 auto origType = op.getVectorType();
401 auto origElements = origType.getNumElements();
402 if (origElements % scale != 0)
405 auto stridedMetadata =
406 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
408 std::tie(std::ignore, linearizedIndices) =
410 rewriter, loc, srcBits, dstBits,
411 stridedMetadata.getConstifiedMixedOffset(),
412 stridedMetadata.getConstifiedMixedSizes(),
413 stridedMetadata.getConstifiedMixedStrides(),
416 FailureOr<Operation *> newMask =
421 auto numElements = (origElements + scale - 1) / scale;
424 rewriter.
create<vector::BitCastOp>(loc, newType, op.getPassThru());
427 auto newLoad = rewriter.
create<vector::MaskedLoadOp>(
428 loc, newType, adaptor.getBase(),
430 newMask.value()->getResult(0), newPassThru);
435 rewriter.
create<vector::BitCastOp>(loc, op.getType(), newLoad);
436 auto select = rewriter.
create<arith::SelectOp>(loc, op.getMask(), bitCast,
438 rewriter.
replaceOp(op, select->getResult(0));
448 struct ConvertVectorTransferRead final
453 matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
457 auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
458 Type oldElementType = op.getType().getElementType();
459 Type newElementType = convertedType.getElementType();
463 if (dstBits % srcBits != 0) {
465 op,
"only dstBits % srcBits == 0 supported");
467 int scale = dstBits / srcBits;
469 auto origElements = op.getVectorType().getNumElements();
470 if (origElements % scale != 0)
473 auto newPadding = rewriter.
create<arith::ExtUIOp>(loc, newElementType,
474 adaptor.getPadding());
476 auto stridedMetadata =
477 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
480 std::tie(std::ignore, linearizedIndices) =
482 rewriter, loc, srcBits, dstBits,
483 stridedMetadata.getConstifiedMixedOffset(),
484 stridedMetadata.getConstifiedMixedSizes(),
485 stridedMetadata.getConstifiedMixedStrides(),
488 auto numElements = (origElements + scale - 1) / scale;
491 auto newRead = rewriter.
create<vector::TransferReadOp>(
492 loc, newReadType, adaptor.getSource(),
497 rewriter.
create<vector::BitCastOp>(loc, op.getType(), newRead);
499 rewriter.
replaceOp(op, bitCast->getResult(0));
513 struct SourceElementRange {
515 int64_t sourceElementIdx;
517 int64_t sourceBitBegin;
518 int64_t sourceBitEnd;
521 struct SourceElementRangeList :
public SmallVector<SourceElementRange> {
527 int64_t computeLeftShiftAmount(int64_t shuffleIdx)
const {
529 for (int64_t i = 0; i < shuffleIdx; ++i)
530 res += (*
this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
549 struct BitCastBitsEnumerator {
550 BitCastBitsEnumerator(VectorType sourceVectorType,
551 VectorType targetVectorType);
553 int64_t getMaxNumberOfEntries() {
554 int64_t numVectors = 0;
555 for (
const auto &l : sourceElementRanges)
556 numVectors =
std::max(numVectors, (int64_t)l.size());
560 VectorType sourceVectorType;
561 VectorType targetVectorType;
636 struct BitCastRewriter {
643 BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
647 VectorType preconditionType,
Operation *op);
651 precomputeMetadata(IntegerType shuffledElementType);
657 const BitCastRewriter::Metadata &metadata);
662 BitCastBitsEnumerator enumerator;
667 [[maybe_unused]]
static raw_ostream &
669 for (
const auto &l : vec) {
671 os <<
"{ " << it.value().sourceElementIdx <<
": b@["
672 << it.value().sourceBitBegin <<
".." << it.value().sourceBitEnd
673 <<
") lshl: " << l.computeLeftShiftAmount(it.index()) <<
" } ";
680 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
681 VectorType targetVectorType)
682 : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
684 assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
685 "requires -D non-scalable vector type");
686 assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
687 "requires -D non-scalable vector type");
688 int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
689 int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
690 LDBG(
"sourceVectorType: " << sourceVectorType);
692 int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
693 int64_t mostMinorTargetDim = targetVectorType.getShape().back();
694 LDBG(
"targetVectorType: " << targetVectorType);
696 int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
697 (void)mostMinorSourceDim;
698 assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
699 "source and target bitwidths must match");
703 for (int64_t resultBit = 0; resultBit < bitwidth;) {
704 int64_t resultElement = resultBit / targetBitWidth;
705 int64_t resultBitInElement = resultBit % targetBitWidth;
706 int64_t sourceElementIdx = resultBit / sourceBitWidth;
707 int64_t sourceBitInElement = resultBit % sourceBitWidth;
708 int64_t step =
std::min(sourceBitWidth - sourceBitInElement,
709 targetBitWidth - resultBitInElement);
710 sourceElementRanges[resultElement].push_back(
711 {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
716 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
717 VectorType targetVectorType)
718 : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
719 LDBG(
"\n" << enumerator.sourceElementRanges);
725 VectorType preconditionType,
727 if (!preconditionType || preconditionType.isScalable())
732 unsigned bitwidth = preconditionType.getElementTypeBitWidth();
733 if (bitwidth % 8 != 0)
739 LogicalResult BitCastRewriter::commonPrecondition(
PatternRewriter &rewriter,
740 VectorType preconditionType,
742 if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
745 if (!preconditionType || preconditionType.getRank() != 1)
761 if (!srcType || !dstType)
763 unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
764 unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
767 if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
768 (dstElemBitwidth % srcElemBitwidth) != 0)
771 if ((srcType.getShape().back() % 2) != 0)
773 op,
"Not an even number of i4 elements in trailing dim");
779 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
781 for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
782 shuffleIdx < e; ++shuffleIdx) {
787 for (
auto &srcEltRangeList : enumerator.sourceElementRanges) {
788 int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
789 ? srcEltRangeList[shuffleIdx].sourceElementIdx
791 shuffles.push_back(sourceElement);
793 int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
794 ? srcEltRangeList[shuffleIdx].sourceBitBegin
796 int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
797 ? srcEltRangeList[shuffleIdx].sourceBitEnd
801 llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
803 masks.push_back(mask);
805 int64_t shiftRight = bitLo;
806 shiftRightAmounts.push_back(
809 int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
810 shiftLeftAmounts.push_back(
814 result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
819 Value BitCastRewriter::genericRewriteStep(
821 Value runningResult,
const BitCastRewriter::Metadata &metadata) {
823 auto shuffleOp = rewriter.
create<vector::ShuffleOp>(
824 loc, initialValue, initialValue, metadata.shuffles);
827 VectorType shuffledVectorType = shuffleOp.getResultVectorType();
828 auto constOp = rewriter.
create<arith::ConstantOp>(
830 Value andValue = rewriter.
create<arith::AndIOp>(loc, shuffleOp, constOp);
833 auto shiftRightConstantOp = rewriter.
create<arith::ConstantOp>(
837 rewriter.
create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
840 auto shiftLeftConstantOp = rewriter.
create<arith::ConstantOp>(
844 rewriter.
create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
848 ? rewriter.
create<arith::OrIOp>(loc, runningResult, shiftedLeft)
851 return runningResult;
859 VectorType srcVecType = cast<VectorType>(srcValue.
getType());
860 assert(srcVecType.getElementType().isSignlessInteger(4) &&
865 constexpr int64_t i4Toi8BitwidthFactor = 2;
866 i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
868 Value i8Vector = rewriter.
create<vector::BitCastOp>(loc, i8VecType, srcValue);
872 constexpr int8_t bitsToShift = 4;
873 auto shiftValues = rewriter.
create<arith::ConstantOp>(
875 Value shl = rewriter.
create<arith::ShLIOp>(loc, i8Vector, shiftValues);
876 Value low = rewriter.
create<arith::ShRSIOp>(loc, shl, shiftValues);
877 Value high = rewriter.
create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
880 return rewriter.
create<vector::InterleaveOp>(loc, low, high);
888 VectorType srcVecType = cast<VectorType>(srcValue.
getType());
889 assert(srcVecType.getElementType().isSignlessInteger(4) &&
894 constexpr int64_t i4Toi8BitwidthFactor = 2;
895 i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
897 Value i8Vector = rewriter.
create<vector::BitCastOp>(loc, i8VecType, srcValue);
901 constexpr uint8_t lowBitsMask = 15;
902 auto lowBitsMaskValues = rewriter.
create<arith::ConstantOp>(
904 Value low = rewriter.
create<arith::AndIOp>(loc, i8VecType, i8Vector,
906 constexpr int8_t highBitsToShift = 4;
907 auto highShiftValues = rewriter.
create<arith::ConstantOp>(
909 Value high = rewriter.
create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
912 return rewriter.
create<vector::InterleaveOp>(loc, low, high);
920 VectorType srcVecType = cast<VectorType>(srcValue.
getType());
921 assert(srcVecType.getElementType().isSignlessInteger(8) &&
925 auto deinterleaveOp = rewriter.
create<vector::DeinterleaveOp>(loc, srcValue);
928 constexpr int8_t i8LowBitMask = 0x0F;
929 VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
930 Value zeroOutMask = rewriter.
create<arith::ConstantOp>(
933 loc, deinterleaveOp.getRes1(), zeroOutMask);
936 constexpr int8_t bitsToShift = 4;
937 auto shiftValues = rewriter.
create<arith::ConstantOp>(
939 Value shlHigh = rewriter.
create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
943 auto mergedHiLowOp = rewriter.
create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
946 auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.
getI4Type());
947 return rewriter.
create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
961 bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
966 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
967 VectorType targetVectorType = bitCastOp.getResultVectorType();
968 BitCastRewriter bcr(sourceVectorType, targetVectorType);
969 if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
973 Value truncValue = truncOp.getIn();
974 auto shuffledElementType =
977 for (
const BitCastRewriter ::Metadata &metadata :
978 bcr.precomputeMetadata(shuffledElementType)) {
979 runningResult = bcr.genericRewriteStep(
980 rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
984 bool narrowing = targetVectorType.getElementTypeBitWidth() <=
985 shuffledElementType.getIntOrFloatBitWidth();
987 if (runningResult.
getType() == bitCastOp.getResultVectorType()) {
988 rewriter.
replaceOp(bitCastOp, runningResult);
991 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
994 if (runningResult.
getType() == bitCastOp.getResultVectorType()) {
995 rewriter.
replaceOp(bitCastOp, runningResult);
998 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1015 template <
typename ExtOpType>
1025 auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
1030 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1031 VectorType targetVectorType = bitCastOp.getResultVectorType();
1032 BitCastRewriter bcr(sourceVectorType, targetVectorType);
1033 if (failed(bcr.commonPrecondition(
1034 rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
1038 Value runningResult;
1039 Value sourceValue = bitCastOp.getSource();
1040 auto shuffledElementType =
1042 for (
const BitCastRewriter::Metadata &metadata :
1043 bcr.precomputeMetadata(shuffledElementType)) {
1044 runningResult = bcr.genericRewriteStep(
1045 rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
1050 cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
1051 shuffledElementType.getIntOrFloatBitWidth();
1054 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1057 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1097 template <
typename ConversionOpType,
bool isSigned>
1104 Value srcValue = conversionOp.getIn();
1105 auto srcVecType = dyn_cast<VectorType>(srcValue.
getType());
1106 auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1129 conversionOp, conversionOp.getType(), subByteExt);
1156 Value srcValue = truncOp.getIn();
1157 auto srcVecType = dyn_cast<VectorType>(srcValue.
getType());
1158 auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
1159 if (!srcVecType || !dstVecType)
1173 auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.
getI8Type());
1175 rewriter.
create<arith::TruncIOp>(loc, i8VecType, srcValue);
1181 rewriter.
replaceOp(truncOp, subByteTrunc);
1206 constexpr
unsigned minNativeBitwidth = 8;
1207 VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
1208 if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
1209 srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
1211 "not a sub-byte transpose");
1215 Location loc = transposeOp.getLoc();
1220 auto srcNativeVecType = srcSubByteVecType.cloneWith(
1222 Value extOp = rewriter.
create<arith::ExtSIOp>(loc, srcNativeVecType,
1223 transposeOp.getVector());
1224 Value newTranspose = rewriter.
create<vector::TransposeOp>(
1225 loc, extOp, transposeOp.getPermutation());
1226 VectorType dstSubByteVecType = transposeOp.getResultVectorType();
1244 patterns.
add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
1245 ConvertVectorMaskedStore, ConvertVectorTransferRead>(
1251 patterns.
add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
1252 RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.
getContext(),
1257 patterns.
add<RewriteAlignedSubByteIntExt<arith::ExtSIOp,
true>,
1258 RewriteAlignedSubByteIntExt<arith::SIToFPOp,
true>,
1259 RewriteAlignedSubByteIntTrunc>(patterns.
getContext(),
1261 patterns.
add<RewriteAlignedSubByteIntExt<arith::ExtUIOp,
false>>(
1267 patterns.
add<RewriteVectorTranspose>(patterns.
getContext(), benefit);
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc, Value srcValue)
Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise ops that take advantage of high-level information to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc, Value srcValue)
Rewrite the i4 -> i8 signed extension into a sequence of shuffles and bitwise ops that take advantage of high-level information to avoid leaving LLVM to scramble with peephole optimizations.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter, VectorType preconditionType, Operation *op)
Verify that the precondition type meets the common preconditions for any conversion.
static FailureOr< Operation * > getCompressedMaskOp(OpBuilder &rewriter, Location loc, Value mask, int origElements, int scale)
Returns a compressed mask.
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc, Value srcValue)
Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and bitwise ops that take advantage of high-level information to avoid leaving LLVM to scramble with peephole optimizations.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, VectorType srcType, VectorType dstType, Operation *op)
Verify that source and destination element types meet the precondition for the supported aligned conversion cases.
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most useful match possible).
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isn't a match, return the benefit of impossibleToMatch.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Converts narrow integer or float types that are not supported by the target hardware to wider types.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplied in operands.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Appends patterns for rewriting vector operations over narrow types with ops over wider types.
void populateVectorNarrowTypeEmulationPatterns(arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating vector operations over narrow types with ops over wider types.
void populateVectorTransposeNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Appends patterns for emulating a sub-byte vector transpose.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. N-1].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as getRootKind().