31 #define DEBUG_TYPE "arm-sme-vector-legalization"
34 #define GEN_PASS_DEF_VECTORLEGALIZATION
35 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
48 static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
49 "op vector size is not multiple of SME tiles");
50 static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
51 "op mask is unsupported for legalization/decomposition");
52 static constexpr StringLiteral
53 kMatchFailureNonPermutationMap(
"op affine map is not a permutation");
54 static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
55 "expected transpose from illegal type to legal type");
85 auto vscale = vector::VectorScaleOp::create(builder, loc);
86 return llvm::map_to_vector(
87 llvm::zip_equal(indices, scalableOffsets), [&](
auto pair) ->
Value {
88 auto [index, base] = pair;
89 auto offset = arith::MulIOp::create(
92 return arith::AddIOp::create(builder, loc, index, offset);
112 SMESubTile smeTile) {
113 return addConstantScalableOffset(builder, loc, indices,
114 {smeTile.row, smeTile.col});
120 bool isSupportedMaskOp(
Value mask) {
126 SMESubTile smeTile) {
127 assert(isSupportedMaskOp(mask));
134 auto smeTileMaskDims = addConstantScalableOffset(
135 builder, loc,
createMask.getOperands(), {-smeTile.row, -smeTile.col});
136 auto smeTileCreateMask = vector::CreateMaskOp::create(
137 builder, loc, smeTile.type.clone(builder.
getI1Type()), smeTileMaskDims);
138 return smeTileCreateMask.getResult();
145 auto decomposeToSMETiles(
OpBuilder &builder, VectorType type,
146 VectorType smeTileType,
147 bool transposeIndices =
false) {
148 return llvm::map_range(
151 {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
152 std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
154 int row = int(indices[0]);
155 int col = int(indices[1]);
156 if (transposeIndices)
158 return SMESubTile{row, col, smeTileType};
164 int getNumberOfSMETilesForVectorType(VectorType type) {
166 "`type` not multiple of SME tiles");
167 int64_t vectorRows = type.getDimSize(0);
168 int64_t vectorCols = type.getDimSize(1);
169 auto elementType = type.getElementType();
171 return (vectorRows * vectorCols) / (minNumElts * minNumElts);
176 struct LegalizeArithConstantOpsByDecomposition
181 matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
183 auto vectorType = dyn_cast<VectorType>(constantOp.getType());
184 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
185 if (!vectorType || !denseAttr || !denseAttr.isSplat())
190 kMatchFailureNotSMETileTypeMultiple);
193 auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
194 auto tileSplat = arith::ConstantOp::create(
195 rewriter, constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
205 struct LegalizeVectorOuterProductOpsByDecomposition
210 matchAndRewrite(vector::OuterProductOp outerProductOp,
211 OneToNOpAdaptor adaptor,
213 auto vectorType = outerProductOp.getResultVectorType();
216 kMatchFailureNotSMETileTypeMultiple);
220 auto loc = outerProductOp.getLoc();
221 if (outerProductOp.isMasked()) {
222 auto maskOp = outerProductOp.getMaskingOp();
223 mask = maskOp.getMask();
228 if (!isSupportedMaskOp(mask))
230 kMatchFailureUnsupportedMaskOp);
238 decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
240 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
241 auto lhs = vector::ScalableExtractOp::create(
242 rewriter, loc, sliceType, outerProductOp.getLhs(), smeTile.row);
243 auto rhs = vector::ScalableExtractOp::create(
244 rewriter, loc, sliceType, outerProductOp.getRhs(), smeTile.col);
245 auto smeOuterProduct = vector::OuterProductOp::create(
246 rewriter, loc, smeTileType, lhs, rhs,
247 !accSMETiles.empty() ? accSMETiles[index] :
Value{},
248 outerProductOp.getKind());
250 auto maskedOuterProduct =
252 resultSMETiles.push_back(maskedOuterProduct->getResult(0));
265 struct LegalizeMaskedVectorOuterProductOpsByDecomposition
270 matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
272 if (
auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
273 maskOp.getMaskableOp())) {
274 LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
277 outerProductOp, rewriter);
285 struct LegalizeTransferReadOpsByDecomposition
290 matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
292 auto vectorType = readOp.getVectorType();
295 kMatchFailureNotSMETileTypeMultiple);
297 auto mask = readOp.getMask();
298 if (!isSupportedMaskOp(mask))
300 kMatchFailureUnsupportedMaskOp);
302 auto permutationMap = readOp.getPermutationMap();
303 if (!permutationMap.isPermutation())
305 kMatchFailureNonPermutationMap);
309 bool transposed = !permutationMap.isIdentity();
311 auto loc = readOp.getLoc();
315 for (SMESubTile smeTile :
316 decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
317 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
318 auto smeRead = vector::TransferReadOp::create(
319 rewriter, loc, smeTileType, readOp.getBase(),
320 getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
321 readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
322 readOp.getInBoundsAttr());
323 resultSMETiles.push_back(smeRead);
333 struct LegalizeTransferWriteOpsByDecomposition
338 matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
340 auto vectorType = writeOp.getVectorType();
343 kMatchFailureNotSMETileTypeMultiple);
345 auto mask = writeOp.getMask();
346 if (!isSupportedMaskOp(mask))
348 kMatchFailureUnsupportedMaskOp);
350 auto permutationMap = writeOp.getPermutationMap();
351 if (!permutationMap.isPermutation())
353 kMatchFailureNonPermutationMap);
357 bool transposed = !permutationMap.isIdentity();
359 auto loc = writeOp.getLoc();
361 auto inputSMETiles = adaptor.getValueToStore();
363 Value destTensorOrMemref = writeOp.getBase();
365 rewriter, vectorType, smeTileType, transposed))) {
366 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
367 auto smeWrite = vector::TransferWriteOp::create(
368 rewriter, loc, inputSMETiles[index], destTensorOrMemref,
369 getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
370 writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
371 if (writeOp.hasPureTensorSemantics())
372 destTensorOrMemref = smeWrite.getResult();
375 if (writeOp.hasPureTensorSemantics())
376 rewriter.
replaceOp(writeOp, destTensorOrMemref);
415 struct LegalizeMultiTileTransferWriteAsStoreLoop
420 matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
422 if (writeOp.hasPureTensorSemantics())
424 writeOp,
"TODO: tensor semantics are unsupported");
426 auto permutationMap = writeOp.getPermutationMap();
427 if (!permutationMap.isPermutation())
429 kMatchFailureNonPermutationMap);
431 bool transposed = !permutationMap.isIdentity();
434 "TODO: transpose unsupported");
436 auto vectorType = writeOp.getVectorType();
439 kMatchFailureNotSMETileTypeMultiple);
443 auto mask = writeOp.getMask();
444 if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
445 vectorType.getDimSize(1) > 16)))
447 kMatchFailureUnsupportedMaskOp);
449 auto loc = writeOp.getLoc();
450 auto createVscaleMultiple =
455 auto minTileSlices = smeTileType.getDimSize(0);
456 VectorType sliceMaskType =
461 auto upperBound = createVscaleMultiple(minTileSlices);
464 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
468 auto inputSMETiles = adaptor.getValueToStore();
469 auto tileSliceIndex = storeLoop.getInductionVar();
471 decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
473 auto tileRow = createVscaleMultiple(smeTile.row);
474 auto tileCol = createVscaleMultiple(smeTile.col);
478 arith::AddIOp::create(rewriter, loc, tileRow, tileSliceIndex);
481 auto storeRow = arith::AddIOp::create(rewriter, loc, sliceIndex,
482 writeOp.getIndices()[0]);
483 auto storeCol = arith::AddIOp::create(rewriter, loc, tileCol,
484 writeOp.getIndices()[1]);
487 Value sliceMask =
nullptr;
489 sliceMask = vector::ExtractOp::create(rewriter, loc, mask,
491 if (sliceMaskType != sliceMask.
getType())
492 sliceMask = vector::ScalableExtractOp::create(
493 rewriter, loc, sliceMaskType, sliceMask, smeTile.col);
499 vector::ExtractOp::create(rewriter, loc,
tile, tileSliceIndex);
500 vector::TransferWriteOp::create(
501 rewriter, loc, slice, writeOp.getBase(),
537 struct FoldExtractFromVectorOfSMELikeCreateMasks
541 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
543 auto loc = extractOp.getLoc();
545 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
548 extractOp,
"extract not from vector.create_mask op");
550 VectorType extractedMaskType =
551 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
552 if (!extractedMaskType)
554 "extracted type is not a vector type");
556 auto numScalable = extractedMaskType.getNumScalableDims();
557 if (numScalable != 2)
559 extractOp,
"expected extracted type to be an SME-like mask");
562 if (extractOp.getStaticPosition().size() != 1)
564 extractOp,
"only a single extraction index is supported");
566 auto frontMaskDim = createMaskOp.getOperand(0);
567 if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
570 "constant vector.create_masks dims should be folded elsewhere");
574 rewriter, loc, extractOp.getMixedPosition()[0]);
575 auto extractionInTrueRegion = arith::CmpIOp::create(
576 rewriter, loc, rewriter.
getI1Type(), arith::CmpIPredicate::slt,
577 extractionIndex, frontMaskDim);
578 auto newMaskFrontDim =
579 arith::SelectOp::create(rewriter, loc, extractionInTrueRegion,
580 createMaskOp.getOperand(1), zero);
583 extractOp, extractedMaskType,
584 ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
590 bool isLegalVectorType(VectorType vType) {
591 bool seenFixedDim =
false;
592 for (
bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
593 seenFixedDim |= !scalableFlag;
594 if (seenFixedDim && scalableFlag)
628 struct LiftIllegalVectorTransposeToMemory
633 if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
638 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
640 auto sourceType = transposeOp.getSourceVectorType();
641 auto resultType = transposeOp.getResultVectorType();
642 if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
644 kMatchFailureNotIllegalToLegal);
647 Value maybeRead = transposeOp.getVector();
650 if (
Value extendSource = getExtensionSource(transposeSourceOp)) {
651 maybeRead = extendSource;
652 extendOp = transposeSourceOp;
655 auto illegalRead = maybeRead.
getDefiningOp<vector::TransferReadOp>();
659 "expected source to be (possibly extended) transfer_read");
661 if (!illegalRead.getPermutationMap().isIdentity())
663 illegalRead,
"expected read to have identity permutation map");
665 auto loc = transposeOp.getLoc();
670 auto readType = illegalRead.getVectorType();
671 auto readSizes = llvm::map_to_vector(
672 llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
673 [&](
auto dim) ->
Value {
674 auto [size, isScalable] = dim;
675 auto dimSize = arith::ConstantIndexOp::create(rewriter, loc, size);
678 auto vscale = vector::VectorScaleOp::create(rewriter, loc);
679 return arith::MulIOp::create(rewriter, loc, vscale, dimSize);
683 memref::SubViewOp::create(rewriter, loc, illegalRead.getBase(),
684 illegalRead.getIndices(), readSizes, strides);
688 Value mask = illegalRead.getMask();
692 mask = vector::TransposeOp::create(rewriter, loc, mask,
693 transposeOp.getPermutation());
698 auto transposedSubview = memref::TransposeOp::create(
700 ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
709 VectorType legalReadType = resultType.clone(readType.getElementType());
712 auto legalRead = vector::TransferReadOp::create(
713 rewriter, loc, legalReadType, transposedSubview, readIndices,
714 illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
722 Value(legalRead), resultType);
759 struct LowerIllegalTransposeStoreViaZA
763 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
765 if (!isSupportedMaskOp(writeOp.getMask()))
767 kMatchFailureUnsupportedMaskOp);
769 auto permutationMap = writeOp.getPermutationMap();
770 if (!permutationMap.isIdentity())
772 kMatchFailureNonPermutationMap);
774 auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
778 auto sourceType = transposeOp.getSourceVectorType();
779 auto resultType = transposeOp.getResultVectorType();
781 if (resultType.getRank() != 2)
784 if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
786 transposeOp,
"not illegal/unsupported SVE transpose");
791 if (sourceType.getDimSize(0) <= 1 ||
792 sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
795 auto loc = writeOp.getLoc();
796 auto createVscaleMultiple =
803 Value undefTile = arm_sme::GetTileOp::create(rewriter, loc, smeTileType);
804 Value destTensorOrMemref = writeOp.getBase();
805 auto numSlicesPerTile =
806 std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
810 decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
816 for (
int d = 0; d < numSlicesPerTile; ++d) {
818 vector::ExtractOp::create(rewriter, loc, transposeOp.getVector(),
820 if (vector.
getType() != smeSliceType) {
821 vector = vector::ScalableExtractOp::create(
822 rewriter, loc, smeSliceType, vector, smeTile.col);
824 tile = vector::InsertOp::create(rewriter, loc, vector,
tile, d);
828 auto transposedRow = createVscaleMultiple(smeTile.col);
835 if (
auto mask = writeOp.getMask()) {
837 maskRows = arith::SubIOp::create(
838 rewriter, loc,
createMask.getOperand(0), transposedRow);
839 maskCols = arith::SubIOp::create(
840 rewriter, loc,
createMask.getOperand(1), transposedCol);
841 maskCols = index::MinSOp::create(rewriter, loc, maskCols, numSlices);
843 maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
844 maskCols = numSlices;
846 auto subMask = vector::CreateMaskOp::create(
847 rewriter, loc, smeTileType.clone(rewriter.
getI1Type()),
851 auto writeIndices = writeOp.getIndices();
853 arith::AddIOp::create(rewriter, loc, transposedRow, writeIndices[0]);
855 arith::AddIOp::create(rewriter, loc, transposedCol, writeIndices[1]);
856 auto smeWrite = vector::TransferWriteOp::create(
857 rewriter, loc,
tile, destTensorOrMemref,
ValueRange{destRow, destCol},
858 transposeMap, subMask, writeOp.getInBounds());
860 if (writeOp.hasPureTensorSemantics())
861 destTensorOrMemref = smeWrite.getResult();
864 if (writeOp.hasPureTensorSemantics())
865 rewriter.
replaceOp(writeOp, destTensorOrMemref);
904 struct LowerColumnTransferReadToLoops
908 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
912 if (readOp.hasPureTensorSemantics())
914 readOp,
"Tensor semantics are unsupported (either bufferize or "
915 "extend this pattern)");
917 auto resType = readOp.getVectorType();
919 if (resType.getRank() != 2)
921 "Only 2D vectors are supported!");
923 if (resType.getShape()[1] != 1)
925 readOp,
"The trailing output dim is != 1 (not supported ATM)");
927 if (!resType.getScalableDims()[0] || resType.getScalableDims()[1])
929 readOp,
"Expected the leading dim to be scalable and the trailing "
934 int64_t numRows = resType.getShape()[0];
935 VectorType newResType =
VectorType::get(numRows, resType.getElementType(),
939 auto loc = readOp.getLoc();
941 auto createVscaleMultiple =
943 auto upperBound = createVscaleMultiple(numRows);
945 Value init = arith::ConstantOp::create(
951 loadLoop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
955 auto tileSliceIndex = loadLoop.getInductionVar();
957 auto idx0 = arith::AddIOp::create(rewriter, loc, tileSliceIndex,
958 readOp.getIndices()[0]);
959 auto idx1 = readOp.getIndices()[1];
961 Value scalar = memref::LoadOp::create(rewriter, loc, readOp.getBase(),
964 Operation *updateInit = vector::InsertOp::create(
965 rewriter, loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex);
967 scf::YieldOp::create(rewriter, loc, updateInit->
getResult(0));
974 auto sc = vector::ShapeCastOp::create(
975 rewriter, loc, readOp.getResult().getType(), loadLoop.getResult(0));
983 struct VectorLegalizationPass
984 :
public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
985 void runOnOperation()
override {
991 [](VectorType vectorType,
995 auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
1005 .add<FoldExtractFromVectorOfSMELikeCreateMasks,
1006 LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory,
1007 LowerIllegalTransposeStoreViaZA>(context);
1010 return signalPassFailure();
1015 patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
1016 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
1018 patterns.add<LegalizeArithConstantOpsByDecomposition,
1019 LegalizeVectorOuterProductOpsByDecomposition,
1020 LegalizeTransferReadOpsByDecomposition,
1021 LegalizeTransferWriteOpsByDecomposition>(converter, context);
1022 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
1029 target.markUnknownOpDynamicallyLegal(
1030 [&](
Operation *op) {
return converter.isLegal(op); });
1031 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1032 return converter.isSignatureLegal(op.getFunctionType());
1036 return signalPassFailure();
1043 return std::make_unique<VectorLegalizationPass>();
static MLIRContext * getContext(OpFoldResult val)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value createMask(AffineForOp vecForOp, VectorizationState &state)
Creates a mask used to filter out garbage elements in the last iteration of unaligned loops.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
IntegerAttr getIndexAttr(int64_t value)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
OperationName getName()
The name of an operation is the key identifier for it.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
std::unique_ptr< Pass > createVectorLegalizationPass()
Pass that legalizes vectors so they can be lowered to ArmSME.
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc)
Returns a functor (int64_t -> Value) which returns a constant vscale multiple.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...