23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/raw_ostream.h"
30 #define DEBUG_TYPE "vector-narrow-type-emulation"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
32 #define DBGSNL() (llvm::dbgs() << "\n")
33 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
45 int origElements,
int scale) {
46 auto numElements = (origElements + scale - 1) / scale;
51 while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
52 if (
auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
53 maskOp = extractOp.getVector().getDefiningOp();
54 extractOps.push_back(extractOp);
57 auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
58 auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
59 if (!createMaskOp && !constantMaskOp)
67 shape.back() = numElements;
71 size_t numMaskOperands = maskOperands.size();
81 newMaskOperands.push_back(
83 newMask = rewriter.
create<vector::CreateMaskOp>(loc, newMaskType,
85 }
else if (constantMaskOp) {
87 constantMaskOp.getMaskDimSizes().getValue();
88 size_t numMaskOperands = maskDimSizes.size();
90 cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
91 IntegerAttr maskIndexAttr =
94 newMaskDimSizes.push_back(maskIndexAttr);
95 newMask = rewriter.
create<vector::ConstantMaskOp>(
96 loc, newMaskType, rewriter.
getArrayAttr(newMaskDimSizes));
99 while (!extractOps.empty()) {
100 newMask = rewriter.
create<vector::ExtractOp>(
101 loc, newMask->
getResults()[0], extractOps.back().getMixedPosition());
102 extractOps.pop_back();
118 matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
122 auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
123 Type oldElementType = op.getValueToStore().getType().getElementType();
124 Type newElementType = convertedType.getElementType();
128 if (dstBits % srcBits != 0) {
130 op,
"only dstBits % srcBits == 0 supported");
132 int scale = dstBits / srcBits;
147 auto origElements = op.getValueToStore().getType().getNumElements();
148 if (origElements % scale != 0)
151 auto stridedMetadata =
152 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
155 std::tie(std::ignore, linearizedIndices) =
157 rewriter, loc, srcBits, dstBits,
158 stridedMetadata.getConstifiedMixedOffset(),
159 stridedMetadata.getConstifiedMixedSizes(),
160 stridedMetadata.getConstifiedMixedStrides(),
163 auto numElements = origElements / scale;
164 auto bitCast = rewriter.
create<vector::BitCastOp>(
166 op.getValueToStore());
169 op, bitCast.
getResult(), adaptor.getBase(),
179 struct ConvertVectorMaskedStore final
184 matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
188 auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
189 Type oldElementType = op.getValueToStore().getType().getElementType();
190 Type newElementType = convertedType.getElementType();
194 if (dstBits % srcBits != 0) {
196 op,
"only dstBits % srcBits == 0 supported");
199 int scale = dstBits / srcBits;
200 int origElements = op.getValueToStore().getType().getNumElements();
201 if (origElements % scale != 0)
204 auto stridedMetadata =
205 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
207 std::tie(std::ignore, linearizedIndicesOfr) =
209 rewriter, loc, srcBits, dstBits,
210 stridedMetadata.getConstifiedMixedOffset(),
211 stridedMetadata.getConstifiedMixedSizes(),
212 stridedMetadata.getConstifiedMixedStrides(),
214 Value linearizedIndices =
240 auto numElements = (origElements + scale - 1) / scale;
242 auto passThru = rewriter.
create<arith::ConstantOp>(
245 auto newLoad = rewriter.
create<vector::MaskedLoadOp>(
246 loc, newType, adaptor.getBase(), linearizedIndices,
247 newMask.value()->getResult(0), passThru);
249 Value valueToStore = rewriter.
create<vector::BitCastOp>(
250 loc, op.getValueToStore().getType(), newLoad);
251 valueToStore = rewriter.
create<arith::SelectOp>(
252 loc, op.getMask(), op.getValueToStore(), valueToStore);
254 rewriter.
create<vector::BitCastOp>(loc, newType, valueToStore);
257 op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
271 matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
275 auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
276 Type oldElementType = op.getType().getElementType();
277 Type newElementType = convertedType.getElementType();
281 if (dstBits % srcBits != 0) {
283 op,
"only dstBits % srcBits == 0 supported");
285 int scale = dstBits / srcBits;
304 auto origElements = op.getVectorType().getNumElements();
305 if (origElements % scale != 0)
308 auto stridedMetadata =
309 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
312 std::tie(std::ignore, linearizedIndices) =
314 rewriter, loc, srcBits, dstBits,
315 stridedMetadata.getConstifiedMixedOffset(),
316 stridedMetadata.getConstifiedMixedSizes(),
317 stridedMetadata.getConstifiedMixedStrides(),
320 auto numElements = (origElements + scale - 1) / scale;
321 auto newLoad = rewriter.
create<vector::LoadOp>(
326 rewriter.
create<vector::BitCastOp>(loc, op.getType(), newLoad);
328 rewriter.
replaceOp(op, bitCast->getResult(0));
337 struct ConvertVectorMaskedLoad final
342 matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
346 auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
347 Type oldElementType = op.getType().getElementType();
348 Type newElementType = convertedType.getElementType();
352 if (dstBits % srcBits != 0) {
354 op,
"only dstBits % srcBits == 0 supported");
356 int scale = dstBits / srcBits;
400 auto origType = op.getVectorType();
401 auto origElements = origType.getNumElements();
402 if (origElements % scale != 0)
405 auto stridedMetadata =
406 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
408 std::tie(std::ignore, linearizedIndices) =
410 rewriter, loc, srcBits, dstBits,
411 stridedMetadata.getConstifiedMixedOffset(),
412 stridedMetadata.getConstifiedMixedSizes(),
413 stridedMetadata.getConstifiedMixedStrides(),
421 auto numElements = (origElements + scale - 1) / scale;
424 rewriter.
create<vector::BitCastOp>(loc, newType, op.getPassThru());
427 auto newLoad = rewriter.
create<vector::MaskedLoadOp>(
428 loc, newType, adaptor.getBase(),
430 newMask.value()->getResult(0), newPassThru);
435 rewriter.
create<vector::BitCastOp>(loc, op.getType(), newLoad);
436 auto select = rewriter.
create<arith::SelectOp>(loc, op.getMask(), bitCast,
438 rewriter.
replaceOp(op, select->getResult(0));
448 struct ConvertVectorTransferRead final
453 matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
457 auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
458 Type oldElementType = op.getType().getElementType();
459 Type newElementType = convertedType.getElementType();
463 if (dstBits % srcBits != 0) {
465 op,
"only dstBits % srcBits == 0 supported");
467 int scale = dstBits / srcBits;
469 auto origElements = op.getVectorType().getNumElements();
470 if (origElements % scale != 0)
473 auto newPadding = rewriter.
create<arith::ExtUIOp>(loc, newElementType,
474 adaptor.getPadding());
476 auto stridedMetadata =
477 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
480 std::tie(std::ignore, linearizedIndices) =
482 rewriter, loc, srcBits, dstBits,
483 stridedMetadata.getConstifiedMixedOffset(),
484 stridedMetadata.getConstifiedMixedSizes(),
485 stridedMetadata.getConstifiedMixedStrides(),
488 auto numElements = (origElements + scale - 1) / scale;
491 auto newRead = rewriter.
create<vector::TransferReadOp>(
492 loc, newReadType, adaptor.getSource(),
497 rewriter.
create<vector::BitCastOp>(loc, op.getType(), newRead);
499 rewriter.
replaceOp(op, bitCast->getResult(0));
513 struct SourceElementRange {
515 int64_t sourceElementIdx;
517 int64_t sourceBitBegin;
518 int64_t sourceBitEnd;
521 struct SourceElementRangeList :
public SmallVector<SourceElementRange> {
527 int64_t computeLeftShiftAmount(int64_t shuffleIdx)
const {
529 for (int64_t i = 0; i < shuffleIdx; ++i)
530 res += (*
this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
549 struct BitCastBitsEnumerator {
550 BitCastBitsEnumerator(VectorType sourceVectorType,
551 VectorType targetVectorType);
553 int64_t getMaxNumberOfEntries() {
554 int64_t numVectors = 0;
555 for (
const auto &l : sourceElementRanges)
556 numVectors =
std::max(numVectors, (int64_t)l.size());
560 VectorType sourceVectorType;
561 VectorType targetVectorType;
636 struct BitCastRewriter {
643 BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
647 VectorType preconditionType,
Operation *op);
651 precomputeMetadata(IntegerType shuffledElementType);
657 const BitCastRewriter::Metadata &metadata);
662 BitCastBitsEnumerator enumerator;
667 [[maybe_unused]]
static raw_ostream &
669 for (
const auto &l : vec) {
671 os <<
"{ " << it.value().sourceElementIdx <<
": b@["
672 << it.value().sourceBitBegin <<
".." << it.value().sourceBitEnd
673 <<
") lshl: " << l.computeLeftShiftAmount(it.index()) <<
" } ";
680 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
681 VectorType targetVectorType)
682 : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
684 assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
685 "requires -D non-scalable vector type");
686 assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
687 "requires -D non-scalable vector type");
688 int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
689 int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
690 LDBG(
"sourceVectorType: " << sourceVectorType);
692 int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
693 int64_t mostMinorTargetDim = targetVectorType.getShape().back();
694 LDBG(
"targetVectorType: " << targetVectorType);
696 int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
697 (void)mostMinorSourceDim;
698 assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
699 "source and target bitwidths must match");
703 for (int64_t resultBit = 0; resultBit < bitwidth;) {
704 int64_t resultElement = resultBit / targetBitWidth;
705 int64_t resultBitInElement = resultBit % targetBitWidth;
706 int64_t sourceElementIdx = resultBit / sourceBitWidth;
707 int64_t sourceBitInElement = resultBit % sourceBitWidth;
708 int64_t step =
std::min(sourceBitWidth - sourceBitInElement,
709 targetBitWidth - resultBitInElement);
710 sourceElementRanges[resultElement].push_back(
711 {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
716 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
717 VectorType targetVectorType)
718 : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
719 LDBG(
"\n" << enumerator.sourceElementRanges);
725 VectorType preconditionType,
727 if (!preconditionType || preconditionType.isScalable())
732 unsigned bitwidth = preconditionType.getElementTypeBitWidth();
733 if (bitwidth % 8 != 0)
740 VectorType preconditionType,
742 if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
745 if (!preconditionType || preconditionType.getRank() != 1)
761 if (!srcType || !dstType)
763 unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
764 unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
767 if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
768 (dstElemBitwidth % srcElemBitwidth) != 0)
771 if ((srcType.getShape().back() % 2) != 0)
773 op,
"Not an even number of i4 elements in trailing dim");
779 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
781 for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
782 shuffleIdx < e; ++shuffleIdx) {
787 for (
auto &srcEltRangeList : enumerator.sourceElementRanges) {
788 int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
789 ? srcEltRangeList[shuffleIdx].sourceElementIdx
791 shuffles.push_back(sourceElement);
793 int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
794 ? srcEltRangeList[shuffleIdx].sourceBitBegin
796 int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
797 ? srcEltRangeList[shuffleIdx].sourceBitEnd
801 llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
803 masks.push_back(mask);
805 int64_t shiftRight = bitLo;
806 shiftRightAmounts.push_back(
809 int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
810 shiftLeftAmounts.push_back(
814 result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
819 Value BitCastRewriter::genericRewriteStep(
821 Value runningResult,
const BitCastRewriter::Metadata &metadata) {
823 auto shuffleOp = rewriter.
create<vector::ShuffleOp>(
824 loc, initialValue, initialValue, metadata.shuffles);
827 VectorType shuffledVectorType = shuffleOp.getResultVectorType();
828 auto constOp = rewriter.
create<arith::ConstantOp>(
830 Value andValue = rewriter.
create<arith::AndIOp>(loc, shuffleOp, constOp);
833 auto shiftRightConstantOp = rewriter.
create<arith::ConstantOp>(
837 rewriter.
create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
840 auto shiftLeftConstantOp = rewriter.
create<arith::ConstantOp>(
844 rewriter.
create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
848 ? rewriter.
create<arith::OrIOp>(loc, runningResult, shiftedLeft)
851 return runningResult;
859 VectorType srcVecType = cast<VectorType>(srcValue.
getType());
860 assert(srcVecType.getElementType().isSignlessInteger(4) &&
865 constexpr int64_t i4Toi8BitwidthFactor = 2;
866 i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
868 Value i8Vector = rewriter.
create<vector::BitCastOp>(loc, i8VecType, srcValue);
872 constexpr int8_t bitsToShift = 4;
873 auto shiftValues = rewriter.
create<arith::ConstantOp>(
875 Value shl = rewriter.
create<arith::ShLIOp>(loc, i8Vector, shiftValues);
876 Value low = rewriter.
create<arith::ShRSIOp>(loc, shl, shiftValues);
877 Value high = rewriter.
create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
880 return rewriter.
create<vector::InterleaveOp>(loc, low, high);
888 VectorType srcVecType = cast<VectorType>(srcValue.
getType());
889 assert(srcVecType.getElementType().isSignlessInteger(8) &&
893 int64_t vecDimSize = srcVecType.getShape().back();
896 assert((vecDimSize % 2) == 0 &&
"Odd number of i4 elements");
897 deinterleaveLowMaskValues.reserve(vecDimSize / 2);
898 deinterleaveHighMaskValues.reserve(vecDimSize / 2);
899 for (
int i = 0, end = vecDimSize; i < end; i += 2) {
900 deinterleaveLowMaskValues.push_back(i);
901 deinterleaveHighMaskValues.push_back(i + 1);
904 auto lowShuffleOp = rewriter.
create<vector::ShuffleOp>(
905 loc, srcValue, srcValue,
907 auto highShuffleOp = rewriter.
create<vector::ShuffleOp>(
908 loc, srcValue, srcValue,
912 constexpr int8_t i8LowBitMask = 0x0F;
913 Value zeroOutMask = rewriter.
create<arith::ConstantOp>(
917 rewriter.
create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
920 constexpr int8_t bitsToShift = 4;
921 VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
922 auto shiftValues = rewriter.
create<arith::ConstantOp>(
925 rewriter.
create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
928 auto mergedHiLowOp = rewriter.
create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
931 auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.
getI4Type());
932 return rewriter.
create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
946 bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
951 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
952 VectorType targetVectorType = bitCastOp.getResultVectorType();
953 BitCastRewriter bcr(sourceVectorType, targetVectorType);
954 if (
failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
958 Value truncValue = truncOp.getIn();
959 auto shuffledElementType =
962 for (
const BitCastRewriter ::Metadata &metadata :
963 bcr.precomputeMetadata(shuffledElementType)) {
964 runningResult = bcr.genericRewriteStep(
965 rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
969 bool narrowing = targetVectorType.getElementTypeBitWidth() <=
970 shuffledElementType.getIntOrFloatBitWidth();
972 if (runningResult.
getType() == bitCastOp.getResultVectorType()) {
973 rewriter.
replaceOp(bitCastOp, runningResult);
976 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
979 if (runningResult.
getType() == bitCastOp.getResultVectorType()) {
980 rewriter.
replaceOp(bitCastOp, runningResult);
983 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1000 template <
typename ExtOpType>
1010 auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
1015 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1016 VectorType targetVectorType = bitCastOp.getResultVectorType();
1017 BitCastRewriter bcr(sourceVectorType, targetVectorType);
1018 if (
failed(bcr.commonPrecondition(
1019 rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
1023 Value runningResult;
1024 Value sourceValue = bitCastOp.getSource();
1025 auto shuffledElementType =
1027 for (
const BitCastRewriter::Metadata &metadata :
1028 bcr.precomputeMetadata(shuffledElementType)) {
1029 runningResult = bcr.genericRewriteStep(
1030 rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
1035 cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
1036 shuffledElementType.getIntOrFloatBitWidth();
1039 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1042 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1072 template <
typename ConversionOpType>
1073 struct RewriteAlignedSubByteIntSignedExt :
OpRewritePattern<ConversionOpType> {
1079 Value srcValue = conversionOp.getIn();
1080 auto srcVecType = dyn_cast<VectorType>(srcValue.
getType());
1081 auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1097 conversionOp, conversionOp.getType(), subByteExt);
1126 Value srcValue = truncOp.getIn();
1127 auto srcVecType = dyn_cast<VectorType>(srcValue.
getType());
1128 auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
1129 if (!srcVecType || !dstVecType)
1134 if (srcVecType.getRank() != 1)
1148 auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.
getI8Type());
1150 rewriter.
create<arith::TruncIOp>(loc, i8VecType, srcValue);
1156 rewriter.
replaceOp(truncOp, subByteTrunc);
1181 constexpr
unsigned minNativeBitwidth = 8;
1182 VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
1183 if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
1184 srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
1186 "not a sub-byte transpose");
1190 Location loc = transposeOp.getLoc();
1195 auto srcNativeVecType = srcSubByteVecType.cloneWith(
1197 Value extOp = rewriter.
create<arith::ExtSIOp>(loc, srcNativeVecType,
1198 transposeOp.getVector());
1199 Value newTranspose = rewriter.
create<vector::TransposeOp>(
1200 loc, extOp, transposeOp.getPermutation());
1201 VectorType dstSubByteVecType = transposeOp.getResultVectorType();
1219 patterns.
add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
1220 ConvertVectorMaskedStore, ConvertVectorTransferRead>(
1226 patterns.
add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
1227 RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.
getContext(),
1232 patterns.
add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
1233 RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
1234 RewriteAlignedSubByteIntTrunc>(patterns.
getContext(),
1240 patterns.
add<RewriteVectorTranspose>(patterns.
getContext(), benefit);
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc, Value srcValue)
Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops that take advantage of high and low i8 half-bytes.
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc, Value srcValue)
Rewrite the i4 -> i8 signed extension into a sequence of shuffles and bitwise ops that take advantage of high and low i8 half-bytes.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter, VectorType preconditionType, Operation *op)
Verify that the precondition type meets the common preconditions for any conversion.
static FailureOr< Operation * > getCompressedMaskOp(OpBuilder &rewriter, Location loc, Value mask, int origElements, int scale)
Returns a compressed mask.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, VectorType srcType, VectorType dstType, Operation *op)
Verify that source and destination element types meet the precondition for the supported aligned conversion cases.
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most useful).
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Converts narrow integer or float types that are not supported by the target hardware to wider types.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying the operands.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Appends patterns for rewriting vector operations over narrow types with ops over wider types.
void populateVectorNarrowTypeEmulationPatterns(arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating vector operations over narrow types with ops over wider types.
void populateVectorTransposeNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Appends patterns for emulating a sub-byte vector transpose.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as getRootKind().