30 #define DEBUG_TYPE "arm-sme-vector-legalization"
33 #define GEN_PASS_DEF_VECTORLEGALIZATION
34 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
47 static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
48 "op vector size is not multiple of SME tiles");
49 static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
50 "op mask is unsupported for legalization/decomposition");
51 static constexpr StringLiteral
52 kMatchFailureNonPermutationMap(
"op affine map is not a permutation");
53 static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
54 "expected transpose from illegal type to legal type");
84 auto vscale = builder.
create<vector::VectorScaleOp>(loc);
85 return llvm::map_to_vector(
86 llvm::zip_equal(indices, scalableOffsets), [&](
auto pair) ->
Value {
87 auto [index, base] = pair;
88 auto offset = builder.
create<arith::MulIOp>(
89 loc, builder.
create<arith::ConstantIndexOp>(loc, base), vscale);
90 return builder.
create<arith::AddIOp>(loc, index, offset);
110 SMESubTile smeTile) {
111 return addConstantScalableOffset(builder, loc, indices,
112 {smeTile.row, smeTile.col});
118 bool isSupportedMaskOp(
Value mask) {
124 SMESubTile smeTile) {
125 assert(isSupportedMaskOp(mask));
132 auto smeTileMaskDims = addConstantScalableOffset(
133 builder, loc,
createMask.getOperands(), {-smeTile.row, -smeTile.col});
134 auto smeTileCreateMask = builder.
create<vector::CreateMaskOp>(
135 loc, smeTile.type.clone(builder.
getI1Type()), smeTileMaskDims);
143 auto decomposeToSMETiles(
OpBuilder &builder, VectorType type,
144 VectorType smeTileType,
145 bool transposeIndices =
false) {
146 return llvm::map_range(
149 {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
150 std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
152 int row = int(indices[0]);
153 int col = int(indices[1]);
154 if (transposeIndices)
156 return SMESubTile{row, col, smeTileType};
162 int getNumberOfSMETilesForVectorType(VectorType type) {
164 "`type` not multiple of SME tiles");
165 int64_t vectorRows = type.getDimSize(0);
166 int64_t vectorCols = type.getDimSize(1);
167 auto elementType = type.getElementType();
169 return (vectorRows * vectorCols) / (minNumElts * minNumElts);
174 struct LegalizeArithConstantOpsByDecomposition
179 matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
181 auto vectorType = dyn_cast<VectorType>(constantOp.getType());
182 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
183 if (!vectorType || !denseAttr || !denseAttr.isSplat())
188 kMatchFailureNotSMETileTypeMultiple);
191 auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
192 auto tileSplat = rewriter.
create<arith::ConstantOp>(
193 constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
195 adaptor.getResultMapping());
203 struct LegalizeVectorOuterProductOpsByDecomposition
208 matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
210 auto vectorType = outerProductOp.getResultVectorType();
213 kMatchFailureNotSMETileTypeMultiple);
217 auto loc = outerProductOp.
getLoc();
218 if (outerProductOp.isMasked()) {
219 auto maskOp = outerProductOp.getMaskingOp();
220 mask = maskOp.getMask();
224 if (!isSupportedMaskOp(mask))
226 kMatchFailureUnsupportedMaskOp);
234 decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
236 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
237 auto lhs = rewriter.
create<vector::ScalableExtractOp>(
238 loc, sliceType, outerProductOp.getLhs(), smeTile.row);
239 auto rhs = rewriter.
create<vector::ScalableExtractOp>(
240 loc, sliceType, outerProductOp.getRhs(), smeTile.col);
241 auto smeOuterProduct = rewriter.
create<vector::OuterProductOp>(
242 loc, smeTileType, lhs, rhs,
243 !accSMETiles.empty() ? accSMETiles[index] :
Value{},
244 outerProductOp.getKind());
246 auto maskedOuterProduct =
248 resultSMETiles.push_back(maskedOuterProduct->getResult(0));
251 rewriter.
replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
261 struct LegalizeMaskedVectorOuterProductOpsByDecomposition
266 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
268 if (
auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
269 maskOp.getMaskableOp())) {
270 LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
273 outerProductOp, rewriter);
281 struct LegalizeTransferReadOpsByDecomposition
286 matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
288 auto vectorType = readOp.getVectorType();
291 kMatchFailureNotSMETileTypeMultiple);
293 auto mask = readOp.getMask();
294 if (!isSupportedMaskOp(mask))
296 kMatchFailureUnsupportedMaskOp);
298 auto permutationMap = readOp.getPermutationMap();
299 if (!permutationMap.isPermutation())
301 kMatchFailureNonPermutationMap);
305 bool transposed = !permutationMap.isIdentity();
307 auto loc = readOp.getLoc();
311 for (SMESubTile smeTile :
312 decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
313 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
314 auto smeRead = rewriter.
create<vector::TransferReadOp>(
315 loc, smeTileType, readOp.getSource(),
316 getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
317 readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
318 readOp.getInBoundsAttr());
319 resultSMETiles.push_back(smeRead);
322 rewriter.
replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
329 struct LegalizeTransferWriteOpsByDecomposition
334 matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
336 auto vectorType = writeOp.getVectorType();
339 kMatchFailureNotSMETileTypeMultiple);
341 auto mask = writeOp.getMask();
342 if (!isSupportedMaskOp(mask))
344 kMatchFailureUnsupportedMaskOp);
346 auto permutationMap = writeOp.getPermutationMap();
347 if (!permutationMap.isPermutation())
349 kMatchFailureNonPermutationMap);
353 bool transposed = !permutationMap.isIdentity();
355 auto loc = writeOp.getLoc();
357 auto inputSMETiles = adaptor.getVector();
359 Value destTensorOrMemref = writeOp.getSource();
361 rewriter, vectorType, smeTileType, transposed))) {
362 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
363 auto smeWrite = rewriter.
create<vector::TransferWriteOp>(
364 loc, inputSMETiles[index], destTensorOrMemref,
365 getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
366 writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
367 if (writeOp.hasPureTensorSemantics())
368 destTensorOrMemref = smeWrite.
getResult();
371 if (writeOp.hasPureTensorSemantics())
372 rewriter.
replaceOp(writeOp, destTensorOrMemref);
411 struct LegalizeMultiTileTransferWriteAsStoreLoop
416 matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
418 if (writeOp.hasPureTensorSemantics())
420 writeOp,
"TODO: tensor semantics are unsupported");
422 auto permutationMap = writeOp.getPermutationMap();
423 if (!permutationMap.isPermutation())
425 kMatchFailureNonPermutationMap);
427 bool transposed = !permutationMap.isIdentity();
430 "TODO: transpose unsupported");
432 auto vectorType = writeOp.getVectorType();
435 kMatchFailureNotSMETileTypeMultiple);
439 auto mask = writeOp.getMask();
440 if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
441 vectorType.getDimSize(1) > 16)))
443 kMatchFailureUnsupportedMaskOp);
445 auto loc = writeOp.getLoc();
446 auto createVscaleMultiple =
451 auto minTileSlices = smeTileType.getDimSize(0);
452 VectorType sliceMaskType =
456 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
457 auto upperBound = createVscaleMultiple(minTileSlices);
458 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
460 rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step);
464 auto inputSMETiles = adaptor.getVector();
465 auto tileSliceIndex = storeLoop.getInductionVar();
467 decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
469 auto tileRow = createVscaleMultiple(smeTile.row);
470 auto tileCol = createVscaleMultiple(smeTile.col);
474 rewriter.
create<arith::AddIOp>(loc, tileRow, tileSliceIndex);
477 auto storeRow = rewriter.
create<arith::AddIOp>(loc, sliceIndex,
478 writeOp.getIndices()[0]);
480 rewriter.
create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);
483 Value sliceMask =
nullptr;
485 sliceMask = rewriter.
create<vector::ExtractOp>(
487 if (sliceMaskType != sliceMask.getType())
488 sliceMask = rewriter.
create<vector::ScalableExtractOp>(
489 loc, sliceMaskType, sliceMask, smeTile.col);
495 rewriter.
create<vector::ExtractOp>(loc,
tile, tileSliceIndex);
496 rewriter.
create<vector::TransferWriteOp>(
497 loc, slice, writeOp.getSource(),
ValueRange{storeRow, storeCol},
532 struct FoldExtractFromVectorOfSMELikeCreateMasks
536 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
538 auto loc = extractOp.getLoc();
540 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
543 extractOp,
"extract not from vector.create_mask op");
545 VectorType extractedMaskType =
546 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
547 if (!extractedMaskType)
549 "extracted type is not a vector type");
551 auto numScalable = extractedMaskType.getNumScalableDims();
552 if (numScalable != 2)
554 extractOp,
"expected extracted type to be an SME-like mask");
557 if (extractOp.getStaticPosition().size() != 1)
559 extractOp,
"only a single extraction index is supported");
561 auto frontMaskDim = createMaskOp.getOperand(0);
562 if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
565 "constant vector.create_masks dims should be folded elsewhere");
567 auto zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
569 rewriter, loc, extractOp.getMixedPosition()[0]);
570 auto extractionInTrueRegion = rewriter.
create<arith::CmpIOp>(
571 loc, rewriter.
getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
573 auto newMaskFrontDim = rewriter.
create<arith::SelectOp>(
574 loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);
577 extractOp, extractedMaskType,
578 ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
584 bool isLegalVectorType(VectorType vType) {
585 bool seenFixedDim =
false;
586 for (
bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
587 seenFixedDim |= !scalableFlag;
588 if (seenFixedDim && scalableFlag)
622 struct LiftIllegalVectorTransposeToMemory
627 if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
632 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
634 auto sourceType = transposeOp.getSourceVectorType();
635 auto resultType = transposeOp.getResultVectorType();
636 if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
638 kMatchFailureNotIllegalToLegal);
641 Value maybeRead = transposeOp.getVector();
644 if (
Value extendSource = getExtensionSource(transposeSourceOp)) {
645 maybeRead = extendSource;
646 extendOp = transposeSourceOp;
649 auto illegalRead = maybeRead.
getDefiningOp<vector::TransferReadOp>();
653 "expected source to be (possibly extended) transfer_read");
655 if (!illegalRead.getPermutationMap().isIdentity())
657 illegalRead,
"expected read to have identity permutation map");
659 auto loc = transposeOp.getLoc();
660 auto zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
661 auto one = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
664 auto readType = illegalRead.getVectorType();
665 auto readSizes = llvm::map_to_vector(
666 llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
667 [&](
auto dim) ->
Value {
668 auto [size, isScalable] = dim;
669 auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
672 auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
673 return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
676 auto readSubview = rewriter.
create<memref::SubViewOp>(
677 loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
682 Value mask = illegalRead.getMask();
686 mask = rewriter.
create<vector::TransposeOp>(loc, mask,
687 transposeOp.getPermutation());
692 auto transposedSubview = rewriter.
create<memref::TransposeOp>(
694 ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
703 VectorType legalReadType = resultType.clone(readType.getElementType());
706 auto legalRead = rewriter.
create<vector::TransferReadOp>(
707 loc, legalReadType, transposedSubview, readIndices,
708 illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
716 Value(legalRead), resultType);
744 struct ConvertIllegalShapeCastOpsToTransposes
748 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
750 auto sourceType = shapeCastOp.getSourceVectorType();
751 auto resultType = shapeCastOp.getResultVectorType();
752 if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
754 kMatchFailureNotIllegalToLegal);
758 if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
760 shapeCastOp,
"expected source to be a 2D scalable vector with a "
761 "trailing unit dim");
763 auto loc = shapeCastOp.getLoc();
767 if (resultType.getRank() == 1)
806 struct LowerIllegalTransposeStoreViaZA
810 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
812 if (!isSupportedMaskOp(writeOp.getMask()))
814 kMatchFailureUnsupportedMaskOp);
816 auto permutationMap = writeOp.getPermutationMap();
817 if (!permutationMap.isIdentity())
819 kMatchFailureNonPermutationMap);
821 auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
825 auto sourceType = transposeOp.getSourceVectorType();
826 auto resultType = transposeOp.getResultVectorType();
828 if (resultType.getRank() != 2)
831 if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
833 transposeOp,
"not illegal/unsupported SVE transpose");
838 if (sourceType.getDimSize(0) <= 1 ||
839 sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
842 auto loc = writeOp.getLoc();
843 auto createVscaleMultiple =
850 Value undefTile = rewriter.
create<arm_sme::GetTileOp>(loc, smeTileType);
851 Value destTensorOrMemref = writeOp.getSource();
852 auto numSlicesPerTile =
853 std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
855 rewriter.
create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
857 decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
863 for (
int d = 0; d < numSlicesPerTile; ++d) {
865 loc, transposeOp.getVector(),
867 if (vector.
getType() != smeSliceType) {
868 vector = rewriter.
create<vector::ScalableExtractOp>(
869 loc, smeSliceType, vector, smeTile.col);
875 auto transposedRow = createVscaleMultiple(smeTile.col);
877 rewriter.
create<arith::ConstantIndexOp>(loc, smeTile.row);
882 if (
auto mask = writeOp.getMask()) {
888 maskCols = rewriter.
create<index::MinSOp>(loc, maskCols, numSlices);
890 maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
891 maskCols = numSlices;
893 auto subMask = rewriter.
create<vector::CreateMaskOp>(
894 loc, smeTileType.clone(rewriter.
getI1Type()),
898 auto writeIndices = writeOp.getIndices();
900 rewriter.
create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
902 rewriter.
create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
903 auto smeWrite = rewriter.
create<vector::TransferWriteOp>(
905 transposeMap, subMask, writeOp.getInBounds());
907 if (writeOp.hasPureTensorSemantics())
908 destTensorOrMemref = smeWrite.
getResult();
911 if (writeOp.hasPureTensorSemantics())
912 rewriter.
replaceOp(writeOp, destTensorOrMemref);
920 struct VectorLegalizationPass
921 :
public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
922 void runOnOperation()
override {
928 [](VectorType vectorType,
932 auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
939 patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
940 LiftIllegalVectorTransposeToMemory,
941 ConvertIllegalShapeCastOpsToTransposes,
942 LowerIllegalTransposeStoreViaZA>(context);
946 patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
947 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
949 patterns.add<LegalizeArithConstantOpsByDecomposition,
950 LegalizeVectorOuterProductOpsByDecomposition,
951 LegalizeTransferReadOpsByDecomposition,
952 LegalizeTransferWriteOpsByDecomposition>(converter, context);
958 return signalPassFailure();
965 return std::make_unique<VectorLegalizationPass>();
static MLIRContext * getContext(OpFoldResult val)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value createMask(AffineForOp vecForOp, VectorizationState &state)
Creates a mask used to filter out garbage elements in the last iteration of unaligned loops.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
IntegerAttr getIndexAttr(int64_t value)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class is a wrapper around OneToNConversionPattern for matching against instances of a particular...
OneToNOpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Specialization of PatternRewriter that OneToNConversionPatterns use.
void replaceOp(Operation *op, ValueRange newValues, const OneToNTypeMapping &resultMapping)
Replaces the results of the operation with the specified list of values mapped back to the original t...
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperationName getName()
The name of an operation is the key identifier for it.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
std::unique_ptr< Pass > createVectorLegalizationPass()
Pass that legalizes vectors so they can be lowered to ArmSME.
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSCFStructuralOneToNTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the provided pattern set with patterns that do 1:N type conversions on (some) SCF ops.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc)
Returns a functor (int64_t -> Value) which returns a constant vscale multiple.
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...