#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
/// Returns a compressed mask: the new mask covers the wider elements, and a
/// wide element is selected whenever any of the `scale` narrow elements it
/// replaces is selected.
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                  Location loc, Value mask,
                                                  int origElements, int scale) {
  auto numElements = (origElements + scale - 1) / scale;

  // Walk up the use-def chain through any vector.extract ops until the
  // defining vector.create_mask or vector.constant_mask op is found.
  Operation *maskOp = mask.getDefiningOp();
  SmallVector<vector::ExtractOp, 2> extractOps;
  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);
    }
  }
  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
  if (!createMaskOp && !constantMaskOp)
    return failure();

  // The compressed mask keeps the original shape except for the innermost
  // dimension, which shrinks to `numElements`.
  // (`shape`, `newMaskType`, the new operand/size containers, and `newMask`
  // are set up in code elided here.)
  shape.back() = numElements;

  if (createMaskOp) {
    auto maskOperands = createMaskOp.getOperands();
    size_t numMaskOperands = maskOperands.size();
    // Ceil-divide the innermost bound by `scale`; outer bounds are unchanged.
    newMaskOperands.push_back(/*ceil-divided innermost bound*/);
    newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                    newMaskOperands);
  } else if (constantMaskOp) {
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    size_t numMaskOperands = maskDimSizes.size();
    int64_t origIndex = maskDimSizes[numMaskOperands - 1];
    int64_t maskIndex = (origIndex + scale - 1) / scale;
    newMaskDimSizes.push_back(maskIndex);
    newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                      newMaskDimSizes);
  }

  // Re-apply the vector.extract ops peeled off above so the compressed mask
  // is taken from the same position.
  while (!extractOps.empty()) {
    newMask = rewriter.create<vector::ExtractOp>(
        loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
    extractOps.pop_back();
  }
  return newMask;
}
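// Illustrative sketch (not taken from the pass output), assuming i4 elements
// emulated in i8 storage (scale = 2): a mask
//   %m = vector.create_mask %c5 : vector<8xi1>
// guarding 8 i4 elements is compressed to a mask over 4 i8 words, where a word
// is selected if either of its two i4 halves is selected:
//   %m_new = vector.create_mask %c3 : vector<4xi1>   // ceil(5 / 2) == 3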
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getValueToStore().getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = origElements / scale;
    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements, newElementType),
        op.getValueToStore());

    rewriter.replaceOpWithNewOp<vector::StoreOp>(
        op, bitCast.getResult(), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
    return success();
  }
};
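// A rough before/after sketch of the pattern above, assuming i4 elements
// emulated in i8 memory (scale = 2); names and shapes are illustrative only:
//
//   vector.store %v, %src[%i] : memref<8xi4>, vector<8xi4>
//
// becomes, after linearizing the indices into the converted i8 memref:
//
//   %bc = vector.bitcast %v : vector<8xi4> to vector<4xi8>
//   vector.store %bc, %converted[%linear_i] : memref<4xi8>, vector<4xi8>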
struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getValueToStore().getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    int scale = dstBits / srcBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    std::tie(std::ignore, linearizedIndicesOfr) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
    Value linearizedIndices =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);

    // Compress the mask to the wider element type and load the words that
    // will be partially overwritten, so unmasked narrow elements keep their
    // original contents (read-modify-write).
    FailureOr<Operation *> newMask =
        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
    if (failed(newMask))
      return failure();

    auto numElements = (origElements + scale - 1) / scale;
    auto newType = VectorType::get(numElements, newElementType);
    auto passThru = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(newType));
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    // Blend the narrow values to store into the loaded data, then store back.
    Value valueToStore = rewriter.create<vector::BitCastOp>(
        loc, op.getValueToStore().getType(), newLoad);
    valueToStore = rewriter.create<arith::SelectOp>(
        loc, op.getMask(), op.getValueToStore(), valueToStore);
    valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
        valueToStore);
    return success();
  }
};
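// Sketch of the read-modify-write above for i4 in i8 memory (scale = 2),
// illustrative only: the compressed mask loads the affected i8 words, the
// narrow values are blended in with arith.select on the original i1 mask,
// and the result is masked-stored back:
//
//   %w   = vector.maskedload %converted[%linear_i], %mask_new, %zeros
//   %old = vector.bitcast %w : vector<4xi8> to vector<8xi4>
//   %new = arith.select %mask, %value_to_store, %old : vector<8xi1>, vector<8xi4>
//   %out = vector.bitcast %new : vector<8xi4> to vector<4xi8>
//   vector.maskedstore %converted[%linear_i], %mask_new, %out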
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origElements = op.getVectorType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = (origElements + scale - 1) / scale;
    auto newLoad = rewriter.create<vector::LoadOp>(
        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));

    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
    rewriter.replaceOp(op, bitCast->getResult(0));
    return success();
  }
};
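// Sketch for the pattern above (illustrative, i4 in i8 memory, scale = 2):
//
//   %v = vector.load %src[%i] : memref<8xi4>, vector<8xi4>
//
// becomes a load of whole i8 words followed by a bitcast back to i4:
//
//   %w = vector.load %converted[%linear_i] : memref<4xi8>, vector<4xi8>
//   %v = vector.bitcast %w : vector<4xi8> to vector<8xi4>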
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    FailureOr<Operation *> newMask =
        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
    if (failed(newMask))
      return failure();

    auto numElements = (origElements + scale - 1) / scale;
    auto newType = VectorType::get(numElements, newElementType);
    auto newPassThru =
        rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());

    // Load the whole (wider) words and bitcast back to the narrow type.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newMask.value()->getResult(0), newPassThru);

    // The compressed mask is coarser than the original one, so re-apply the
    // original mask to keep the pass-through value in lanes that were not
    // requested.
    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
    auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
                                                   op.getPassThru());
    rewriter.replaceOp(op, select->getResult(0));
    return success();
  }
};
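// Sketch for the masked variant above (illustrative, scale = 2): the load uses
// the compressed i8-level mask and a bitcast pass-through, and the original
// i4-level mask is re-applied afterwards with arith.select so lanes that were
// never requested still carry the original pass-through value.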
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origElements = op.getVectorType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    // The padding value must be widened to the new element type as well.
    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = (origElements + scale - 1) / scale;
    auto newReadType = VectorType::get(numElements, newElementType);
    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, newReadType, adaptor.getSource(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
    rewriter.replaceOp(op, bitCast->getResult(0));
    return success();
  }
};
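// Sketch for the pattern above (illustrative, scale = 2): a transfer_read of
// vector<8xi4> with padding %p becomes a transfer_read of vector<4xi8> with
// padding `arith.extui %p : i4 to i8`, followed by a bitcast back to
// vector<8xi4>.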
/// Bookkeeping for a contiguous run of bits copied out of one source element.
struct SourceElementRange {
  /// The index of the source vector element that contributes bits.
  int64_t sourceElementIdx;
  /// The range of bits [sourceBitBegin, sourceBitEnd) taken from that element.
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// The bit offset at which the range at `shuffleIdx` lands in the result
  /// element, i.e. the sum of the widths of all preceding ranges.
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};

struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};
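// Worked example (illustrative): enumerating vector<4xi8> -> vector<2xi16>
// walks the 32 result bits in chunks bounded by both element widths, so each
// result element gets two entries, one per contributing source byte:
//   sourceElementRanges[0] = { {0, 0, 8}, {1, 0, 8} }
//   sourceElementRanges[1] = { {2, 0, 8}, {3, 0, 8} }
// and getMaxNumberOfEntries() is 2, i.e. two shuffle/mask/shift steps.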
struct BitCastRewriter {
  /// Per-step metadata: shuffle indices, masks, and shift amounts
  /// (definition elided).
  struct Metadata;

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  SmallVector<Metadata> precomputeMetadata(IntegerType shuffledElementType);

  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

  BitCastBitsEnumerator enumerator;
};
[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
    }
    os << "\n";
  }
  return os;
}
BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG("sourceVectorType: " << sourceVectorType);

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG("targetVectorType: " << targetVectorType);

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  // Walk the result bits one contiguous run at a time and record, for each
  // target element, which source element contributes which bit range.
  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
}
BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG("\n" << enumerator.sourceElementRanges);
}

/// Verify that the precondition type meets the common preconditions for any
/// conversion.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!preconditionType || preconditionType.isScalable())
    return failure();
  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)
    return failure();
  return success();
}
LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
    return failure();
  if (!preconditionType || preconditionType.getRank() != 1)
    return failure();
  return commonConversionPrecondition(rewriter, preconditionType, op);
}
/// Verify that the source and destination element types meet the precondition
/// for the supported aligned conversions.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType srcType,
                                                   VectorType dstType,
                                                   Operation *op) {
  if (!srcType || !dstType)
    return failure();
  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
  // Only i4 -> byte-multiple widths are currently handled.
  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
      (dstElemBitwidth % srcElemBitwidth) != 0)
    return failure();
  if ((srcType.getShape().back() % 2) != 0)
    return rewriter.notifyMatchFailure(
        op, "Not an even number of i4 elements in trailing dim");
  return success();
}
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<APInt> masks;
    SmallVector<int64_t> shiftRightAmounts, shiftLeftAmounts;

    // For each result element, record which source element participates in
    // this step, the bit range it contributes, and how far it must shift.
    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
                                  : 0;
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
                          : 0;
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
                          : 0;
      auto mask =
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
                                  bitLo, bitHi);
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(shiftRight);

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(shiftLeft);
    }
    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
  }
  return result;
}
Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  // Gather the source elements that contribute to this step.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Mask out the bits that do not belong to this step.
  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = rewriter.create<arith::ConstantOp>(
      loc, /*dense constant built from metadata.masks*/);
  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);

  // Shift the masked bits down to bit 0, then up to their final position.
  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
      loc, /*dense constant built from metadata.shiftRightAmounts*/);
  Value shiftedRight =
      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
      loc, /*dense constant built from metadata.shiftLeftAmounts*/);
  Value shiftedLeft =
      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

  // OR the result of this step into the running result.
  runningResult =
      runningResult
          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
          : shiftedLeft;

  return runningResult;
}
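// Illustrative reading (for the vector<4xi8> -> vector<2xi16> enumeration
// sketched earlier): step 0 gathers source elements {0, 2}, keeps bits [0, 8)
// and shifts them left by 0; step 1 gathers {1, 3}, keeps bits [0, 8) and
// shifts them left by 8. OR-ing the two steps assembles each 16-bit result
// from its two source bytes. All of this runs at the width of
// `shuffledElementType` and is truncated or extended to the target element
// type at the end.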
/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
/// bitwise ops.
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
                                    Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
  constexpr int64_t i4Toi8BitwidthFactor = 2;
  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);

  // 2. Extend i4 elements to i8 elements using shifts: the low nibbles are
  //    recovered with shl + arithmetic shr, the high nibbles with a single
  //    arithmetic shr.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);

  // 3. Interleave low and high i8 elements to recover the original order.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
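// Sketch of the resulting IR (illustrative, for a vector<8xi4> input; %c4
// stands for the splat constant created above):
//
//   %bc   = vector.bitcast %src : vector<8xi4> to vector<4xi8>
//   %shl  = arith.shli  %bc,  %c4 : vector<4xi8>
//   %low  = arith.shrsi %shl, %c4 : vector<4xi8>  // low nibbles, sign-extended
//   %high = arith.shrsi %bc,  %c4 : vector<4xi8>  // high nibbles, sign-extended
//   %res  = vector.interleave %low, %high : vector<4xi8>  // -> vector<8xi8>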
/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
/// bitwise ops.
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
                                      Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
  constexpr int64_t i4Toi8BitwidthFactor = 2;
  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);

  // 2. Zero-extend: the low nibbles are obtained by masking out the high
  //    nibble, the high nibbles by a logical shift right.
  constexpr uint8_t lowBitsMask = 15; // 0b00001111
  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
                                             lowBitsMaskValues);
  constexpr int8_t highBitsToShift = 4;
  auto highShiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);

  // 3. Interleave low and high i8 elements.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
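// The unsigned variant above follows the same structure as the signed one, but
// recovers the low nibbles with `andi 0x0F` and the high nibbles with a
// logical shift right, so no sign bits are propagated.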
/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
/// ops that repack pairs of i4 values into single bytes.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  // 1. De-interleave low and high i8 elements.
  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);

  // 2. Zero out the upper nibble of each low i8 element.
  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
  Value zeroOutLow = rewriter.create<arith::AndIOp>(
      loc, deinterleaveOp.getRes1(), zeroOutMask);

  // 3. Move each high element's nibble into the upper half of the byte and
  //    merge it with the low nibble.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
                                                 shiftValues);
  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);

  // 4. Bitcast the merged i8 vector back to the original i4 vector type.
  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
}
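// Sketch of the resulting IR (illustrative, for a vector<8xi8> input being
// truncated to vector<8xi4>):
//
//   %lo, %hi = vector.deinterleave %src : vector<8xi8> -> vector<4xi8>
//   %lo4     = arith.andi %lo, %c0F      // keep the low nibble of each low byte
//   %hi4     = arith.shli %hi, %c4       // move each high byte's nibble up
//   %merged  = arith.ori  %lo4, %hi4     // pack two i4 values per byte
//   %res     = vector.bitcast %merged : vector<4xi8> to vector<8xi4>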
/// Rewrite vector.bitcast(arith.trunci) as a sequence of shuffles and bitwise
/// ops, using BitCastRewriter.
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
    // The source must be the result of an arith.trunci.
    auto truncOp =
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
    if (!truncOp)
      return failure();

    // Set up the BitCastRewriter and verify the preconditions.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite one precomputed step at a time.
    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
    }

    // Finalize: truncate or extend to the requested result type if needed.
    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::TruncIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    } else {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    }
    return success();
  }
};
/// Rewrite ext{si,ui}(vector.bitcast) as a sequence of shuffles and bitwise
/// ops, using BitCastRewriter.
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  using OpRewritePattern<ExtOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const override {
    // The source must be the result of a vector.bitcast.
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
    if (!bitCastOp)
      return failure();

    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();

    // Perform the rewrite one precomputed step at a time.
    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
    }

    // Finalize: truncate or extend to the requested result type if needed.
    bool narrowing =
        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<ExtOpType>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    }
    return success();
  }
};
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
  using OpRewritePattern<ConversionOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = conversionOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
    // ... (precondition checks elided; `subByteExt` is produced by
    // rewriteI4ToI8SignedExt or rewriteI4ToI8UnsignedExt depending on
    // `isSigned`.)
    rewriter.replaceOpWithNewOp<ConversionOpType>(
        conversionOp, conversionOp.getType(), subByteExt);
    return success();
  }
};
  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)
      return failure();

    // ... (alignment precondition checks elided.)

    // Truncate to i8 first, then rewrite the i8 -> i4 truncation itself.
    Location loc = truncOp.getLoc();
    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
    Value i8TruncVal =
        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
    rewriter.replaceOp(truncOp, subByteTrunc);
    return success();
  }
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Precondition: the source must be a signless sub-byte integer vector.
    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
      return rewriter.notifyMatchFailure(transposeOp,
                                         "not a sub-byte transpose");
    }

    // Extend the sub-byte vector to a byte-wide element type, transpose at
    // the native width, and truncate the result back to the sub-byte type.
    Location loc = transposeOp.getLoc();
    auto srcNativeVecType = srcSubByteVecType.cloneWith(
        std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
    Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
                                                  transposeOp.getVector());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
    rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
                                                 newTranspose);
    return success();
  }
};
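// Illustrative effect of the pattern above: a transpose of vector<2x4xi4>
// becomes an arith.extsi to vector<2x4xi8>, a native i8 vector.transpose to
// vector<4x2xi8>, and an arith.trunci back to vector<4x2xi4>.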
void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // Populate `vector.*` conversion patterns.
  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());
}

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);

  // Patterns for the aligned i4 cases.
  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(), benefit);
  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
      patterns.getContext(), benefit);
}

void vector::populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
}
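// Typical usage from a pass (a minimal sketch, not part of this file; the
// surrounding pass boilerplate is illustrative only):
//
//   void runOnOperation() override {
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorNarrowTypeRewritePatterns(patterns);
//     vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }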