31 #define DEBUG_TYPE "arm-sme-vector-legalization"
34 #define GEN_PASS_DEF_VECTORLEGALIZATION
35 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
48 static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
49 "op vector size is not multiple of SME tiles");
50 static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
51 "op mask is unsupported for legalization/decomposition");
52 static constexpr StringLiteral
53 kMatchFailureNonPermutationMap(
"op affine map is not a permutation");
54 static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
55 "expected transpose from illegal type to legal type");
85 auto vscale = builder.
create<vector::VectorScaleOp>(loc);
86 return llvm::map_to_vector(
87 llvm::zip_equal(indices, scalableOffsets), [&](
auto pair) ->
Value {
88 auto [index, base] = pair;
89 auto offset = builder.
create<arith::MulIOp>(
90 loc, builder.
create<arith::ConstantIndexOp>(loc, base), vscale);
91 return builder.
create<arith::AddIOp>(loc, index, offset);
111 SMESubTile smeTile) {
112 return addConstantScalableOffset(builder, loc, indices,
113 {smeTile.row, smeTile.col});
119 bool isSupportedMaskOp(
Value mask) {
125 SMESubTile smeTile) {
126 assert(isSupportedMaskOp(mask));
133 auto smeTileMaskDims = addConstantScalableOffset(
134 builder, loc,
createMask.getOperands(), {-smeTile.row, -smeTile.col});
135 auto smeTileCreateMask = builder.
create<vector::CreateMaskOp>(
136 loc, smeTile.type.clone(builder.
getI1Type()), smeTileMaskDims);
144 auto decomposeToSMETiles(
OpBuilder &builder, VectorType type,
145 VectorType smeTileType,
146 bool transposeIndices =
false) {
147 return llvm::map_range(
150 {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
151 std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
153 int row = int(indices[0]);
154 int col = int(indices[1]);
155 if (transposeIndices)
157 return SMESubTile{row, col, smeTileType};
163 int getNumberOfSMETilesForVectorType(VectorType type) {
165 "`type` not multiple of SME tiles");
166 int64_t vectorRows = type.getDimSize(0);
167 int64_t vectorCols = type.getDimSize(1);
168 auto elementType = type.getElementType();
170 return (vectorRows * vectorCols) / (minNumElts * minNumElts);
175 struct LegalizeArithConstantOpsByDecomposition
180 matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
182 auto vectorType = dyn_cast<VectorType>(constantOp.getType());
183 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
184 if (!vectorType || !denseAttr || !denseAttr.isSplat())
189 kMatchFailureNotSMETileTypeMultiple);
192 auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
193 auto tileSplat = rewriter.
create<arith::ConstantOp>(
194 constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
204 struct LegalizeVectorOuterProductOpsByDecomposition
209 matchAndRewrite(vector::OuterProductOp outerProductOp,
210 OneToNOpAdaptor adaptor,
212 auto vectorType = outerProductOp.getResultVectorType();
215 kMatchFailureNotSMETileTypeMultiple);
219 auto loc = outerProductOp.
getLoc();
220 if (outerProductOp.isMasked()) {
221 auto maskOp = outerProductOp.getMaskingOp();
222 mask = maskOp.getMask();
227 if (!isSupportedMaskOp(mask))
229 kMatchFailureUnsupportedMaskOp);
237 decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
239 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
240 auto lhs = rewriter.
create<vector::ScalableExtractOp>(
241 loc, sliceType, outerProductOp.getLhs(), smeTile.row);
242 auto rhs = rewriter.
create<vector::ScalableExtractOp>(
243 loc, sliceType, outerProductOp.getRhs(), smeTile.col);
244 auto smeOuterProduct = rewriter.
create<vector::OuterProductOp>(
245 loc, smeTileType, lhs, rhs,
246 !accSMETiles.empty() ? accSMETiles[index] :
Value{},
247 outerProductOp.getKind());
249 auto maskedOuterProduct =
251 resultSMETiles.push_back(maskedOuterProduct->getResult(0));
264 struct LegalizeMaskedVectorOuterProductOpsByDecomposition
269 matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
271 if (
auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
272 maskOp.getMaskableOp())) {
273 LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
276 outerProductOp, rewriter);
284 struct LegalizeTransferReadOpsByDecomposition
289 matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
291 auto vectorType = readOp.getVectorType();
294 kMatchFailureNotSMETileTypeMultiple);
296 auto mask = readOp.getMask();
297 if (!isSupportedMaskOp(mask))
299 kMatchFailureUnsupportedMaskOp);
301 auto permutationMap = readOp.getPermutationMap();
302 if (!permutationMap.isPermutation())
304 kMatchFailureNonPermutationMap);
308 bool transposed = !permutationMap.isIdentity();
310 auto loc = readOp.getLoc();
314 for (SMESubTile smeTile :
315 decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
316 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
317 auto smeRead = rewriter.
create<vector::TransferReadOp>(
318 loc, smeTileType, readOp.getSource(),
319 getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
320 readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
321 readOp.getInBoundsAttr());
322 resultSMETiles.push_back(smeRead);
332 struct LegalizeTransferWriteOpsByDecomposition
337 matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
339 auto vectorType = writeOp.getVectorType();
342 kMatchFailureNotSMETileTypeMultiple);
344 auto mask = writeOp.getMask();
345 if (!isSupportedMaskOp(mask))
347 kMatchFailureUnsupportedMaskOp);
349 auto permutationMap = writeOp.getPermutationMap();
350 if (!permutationMap.isPermutation())
352 kMatchFailureNonPermutationMap);
356 bool transposed = !permutationMap.isIdentity();
358 auto loc = writeOp.getLoc();
360 auto inputSMETiles = adaptor.getVector();
362 Value destTensorOrMemref = writeOp.getSource();
364 rewriter, vectorType, smeTileType, transposed))) {
365 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
366 auto smeWrite = rewriter.
create<vector::TransferWriteOp>(
367 loc, inputSMETiles[index], destTensorOrMemref,
368 getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
369 writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
370 if (writeOp.hasPureTensorSemantics())
371 destTensorOrMemref = smeWrite.
getResult();
374 if (writeOp.hasPureTensorSemantics())
375 rewriter.
replaceOp(writeOp, destTensorOrMemref);
414 struct LegalizeMultiTileTransferWriteAsStoreLoop
419 matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
421 if (writeOp.hasPureTensorSemantics())
423 writeOp,
"TODO: tensor semantics are unsupported");
425 auto permutationMap = writeOp.getPermutationMap();
426 if (!permutationMap.isPermutation())
428 kMatchFailureNonPermutationMap);
430 bool transposed = !permutationMap.isIdentity();
433 "TODO: transpose unsupported");
435 auto vectorType = writeOp.getVectorType();
438 kMatchFailureNotSMETileTypeMultiple);
442 auto mask = writeOp.getMask();
443 if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
444 vectorType.getDimSize(1) > 16)))
446 kMatchFailureUnsupportedMaskOp);
448 auto loc = writeOp.getLoc();
449 auto createVscaleMultiple =
454 auto minTileSlices = smeTileType.getDimSize(0);
455 VectorType sliceMaskType =
459 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
460 auto upperBound = createVscaleMultiple(minTileSlices);
461 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
463 rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step);
467 auto inputSMETiles = adaptor.getVector();
468 auto tileSliceIndex = storeLoop.getInductionVar();
470 decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
472 auto tileRow = createVscaleMultiple(smeTile.row);
473 auto tileCol = createVscaleMultiple(smeTile.col);
477 rewriter.
create<arith::AddIOp>(loc, tileRow, tileSliceIndex);
480 auto storeRow = rewriter.
create<arith::AddIOp>(loc, sliceIndex,
481 writeOp.getIndices()[0]);
483 rewriter.
create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);
486 Value sliceMask =
nullptr;
488 sliceMask = rewriter.
create<vector::ExtractOp>(
490 if (sliceMaskType != sliceMask.getType())
491 sliceMask = rewriter.
create<vector::ScalableExtractOp>(
492 loc, sliceMaskType, sliceMask, smeTile.col);
498 rewriter.
create<vector::ExtractOp>(loc,
tile, tileSliceIndex);
499 rewriter.
create<vector::TransferWriteOp>(
500 loc, slice, writeOp.getSource(),
ValueRange{storeRow, storeCol},
535 struct FoldExtractFromVectorOfSMELikeCreateMasks
539 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
541 auto loc = extractOp.getLoc();
543 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
546 extractOp,
"extract not from vector.create_mask op");
548 VectorType extractedMaskType =
549 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
550 if (!extractedMaskType)
552 "extracted type is not a vector type");
554 auto numScalable = extractedMaskType.getNumScalableDims();
555 if (numScalable != 2)
557 extractOp,
"expected extracted type to be an SME-like mask");
560 if (extractOp.getStaticPosition().size() != 1)
562 extractOp,
"only a single extraction index is supported");
564 auto frontMaskDim = createMaskOp.getOperand(0);
565 if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
568 "constant vector.create_masks dims should be folded elsewhere");
570 auto zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
572 rewriter, loc, extractOp.getMixedPosition()[0]);
573 auto extractionInTrueRegion = rewriter.
create<arith::CmpIOp>(
574 loc, rewriter.
getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
576 auto newMaskFrontDim = rewriter.
create<arith::SelectOp>(
577 loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);
580 extractOp, extractedMaskType,
581 ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
587 bool isLegalVectorType(VectorType vType) {
588 bool seenFixedDim =
false;
589 for (
bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
590 seenFixedDim |= !scalableFlag;
591 if (seenFixedDim && scalableFlag)
625 struct LiftIllegalVectorTransposeToMemory
630 if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
635 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
637 auto sourceType = transposeOp.getSourceVectorType();
638 auto resultType = transposeOp.getResultVectorType();
639 if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
641 kMatchFailureNotIllegalToLegal);
644 Value maybeRead = transposeOp.getVector();
647 if (
Value extendSource = getExtensionSource(transposeSourceOp)) {
648 maybeRead = extendSource;
649 extendOp = transposeSourceOp;
652 auto illegalRead = maybeRead.
getDefiningOp<vector::TransferReadOp>();
656 "expected source to be (possibly extended) transfer_read");
658 if (!illegalRead.getPermutationMap().isIdentity())
660 illegalRead,
"expected read to have identity permutation map");
662 auto loc = transposeOp.getLoc();
663 auto zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
664 auto one = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
667 auto readType = illegalRead.getVectorType();
668 auto readSizes = llvm::map_to_vector(
669 llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
670 [&](
auto dim) ->
Value {
671 auto [size, isScalable] = dim;
672 auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
675 auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
676 return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
679 auto readSubview = rewriter.
create<memref::SubViewOp>(
680 loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
685 Value mask = illegalRead.getMask();
689 mask = rewriter.
create<vector::TransposeOp>(loc, mask,
690 transposeOp.getPermutation());
695 auto transposedSubview = rewriter.
create<memref::TransposeOp>(
697 ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
706 VectorType legalReadType = resultType.clone(readType.getElementType());
709 auto legalRead = rewriter.
create<vector::TransferReadOp>(
710 loc, legalReadType, transposedSubview, readIndices,
711 illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
719 Value(legalRead), resultType);
747 struct ConvertIllegalShapeCastOpsToTransposes
751 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
753 auto sourceType = shapeCastOp.getSourceVectorType();
754 auto resultType = shapeCastOp.getResultVectorType();
755 if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
757 kMatchFailureNotIllegalToLegal);
761 if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
763 shapeCastOp,
"expected source to be a 2D scalable vector with a "
764 "trailing unit dim");
766 auto loc = shapeCastOp.getLoc();
770 if (resultType.getRank() == 1)
809 struct LowerIllegalTransposeStoreViaZA
813 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
815 if (!isSupportedMaskOp(writeOp.getMask()))
817 kMatchFailureUnsupportedMaskOp);
819 auto permutationMap = writeOp.getPermutationMap();
820 if (!permutationMap.isIdentity())
822 kMatchFailureNonPermutationMap);
824 auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
828 auto sourceType = transposeOp.getSourceVectorType();
829 auto resultType = transposeOp.getResultVectorType();
831 if (resultType.getRank() != 2)
834 if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
836 transposeOp,
"not illegal/unsupported SVE transpose");
841 if (sourceType.getDimSize(0) <= 1 ||
842 sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
845 auto loc = writeOp.getLoc();
846 auto createVscaleMultiple =
853 Value undefTile = rewriter.
create<arm_sme::GetTileOp>(loc, smeTileType);
854 Value destTensorOrMemref = writeOp.getSource();
855 auto numSlicesPerTile =
856 std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
858 rewriter.
create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
860 decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
866 for (
int d = 0; d < numSlicesPerTile; ++d) {
868 loc, transposeOp.getVector(),
870 if (vector.
getType() != smeSliceType) {
871 vector = rewriter.
create<vector::ScalableExtractOp>(
872 loc, smeSliceType, vector, smeTile.col);
878 auto transposedRow = createVscaleMultiple(smeTile.col);
880 rewriter.
create<arith::ConstantIndexOp>(loc, smeTile.row);
885 if (
auto mask = writeOp.getMask()) {
891 maskCols = rewriter.
create<index::MinSOp>(loc, maskCols, numSlices);
893 maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
894 maskCols = numSlices;
896 auto subMask = rewriter.
create<vector::CreateMaskOp>(
897 loc, smeTileType.clone(rewriter.
getI1Type()),
901 auto writeIndices = writeOp.getIndices();
903 rewriter.
create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
905 rewriter.
create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
906 auto smeWrite = rewriter.
create<vector::TransferWriteOp>(
908 transposeMap, subMask, writeOp.getInBounds());
910 if (writeOp.hasPureTensorSemantics())
911 destTensorOrMemref = smeWrite.
getResult();
914 if (writeOp.hasPureTensorSemantics())
915 rewriter.
replaceOp(writeOp, destTensorOrMemref);
923 struct VectorLegalizationPass
924 :
public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
925 void runOnOperation()
override {
931 [](VectorType vectorType,
935 auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
944 rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
945 LiftIllegalVectorTransposeToMemory,
946 ConvertIllegalShapeCastOpsToTransposes,
947 LowerIllegalTransposeStoreViaZA>(context);
950 return signalPassFailure();
955 patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
956 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
958 patterns.add<LegalizeArithConstantOpsByDecomposition,
959 LegalizeVectorOuterProductOpsByDecomposition,
960 LegalizeTransferReadOpsByDecomposition,
961 LegalizeTransferWriteOpsByDecomposition>(converter, context);
962 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
969 target.markUnknownOpDynamicallyLegal(
970 [&](
Operation *op) {
return converter.isLegal(op); });
971 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
972 return converter.isSignatureLegal(op.getFunctionType());
976 return signalPassFailure();
983 return std::make_unique<VectorLegalizationPass>();
static MLIRContext * getContext(OpFoldResult val)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value createMask(AffineForOp vecForOp, VectorizationState &state)
Creates a mask used to filter out garbage elements in the last iteration of unaligned loops.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
IntegerAttr getIndexAttr(int64_t value)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around an Attribute.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperationName getName()
The name of an operation is the key identifier for it.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePattern is the common base class for all DAG to DAG replacements.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
std::unique_ptr< Pass > createVectorLegalizationPass()
Pass that legalizes vectors so they can be lowered to ArmSME.
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion target.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc)
Returns a functor (int64_t -> Value) which returns a constant vscale multiple.
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixed point is reached.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and sinking them, in their order of occurrence in forOps, under each of the targets.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the given type converter.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalized by the conversion framework.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.