31#define DEBUG_TYPE "arm-sme-vector-legalization"
34#define GEN_PASS_DEF_VECTORLEGALIZATION
35#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
// Shared match-failure reason strings. The patterns in this file hand these
// to rewriter.notifyMatchFailure(...) so that equivalent rejections report
// the same text.

// Reported when the op's vector type is not a multiple of SME tiles.
48static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
49 "op vector size is not multiple of SME tiles");
// Reported when isSupportedMaskOp(mask) rejects the op's mask.
50static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
51 "op mask is unsupported for legalization/decomposition");
// Reported when a transfer op's permutation map fails the pattern's
// isPermutation()/isIdentity() check.
52static constexpr StringLiteral
53 kMatchFailureNonPermutationMap(
"op affine map is not a permutation");
// Reported when a transpose is not from an illegal source type to a legal
// result type (see the isLegalVectorType checks on the transpose pattern).
54static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
55 "expected transpose from illegal type to legal type");
85 auto vscale = vector::VectorScaleOp::create(builder, loc);
86 return llvm::map_to_vector(
87 llvm::zip_equal(
indices, scalableOffsets), [&](
auto pair) ->
Value {
88 auto [
index, base] = pair;
89 auto offset = arith::MulIOp::create(
92 return arith::AddIOp::create(builder, loc,
index, offset);
112 SMESubTile smeTile) {
113 return addConstantScalableOffset(builder, loc,
indices,
114 {smeTile.row, smeTile.col});
120bool isSupportedMaskOp(
Value mask) {
126 SMESubTile smeTile) {
127 assert(isSupportedMaskOp(mask));
134 auto smeTileMaskDims = addConstantScalableOffset(
135 builder, loc,
createMask.getOperands(), {-smeTile.row, -smeTile.col});
136 auto smeTileCreateMask = vector::CreateMaskOp::create(
137 builder, loc, smeTile.type.clone(builder.
getI1Type()), smeTileMaskDims);
138 return smeTileCreateMask.getResult();
145auto decomposeToSMETiles(
OpBuilder &builder, VectorType type,
146 VectorType smeTileType,
147 bool transposeIndices =
false) {
148 return llvm::map_range(
151 {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
152 std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
154 int row = int(indices[0]);
155 int col = int(indices[1]);
156 if (transposeIndices)
158 return SMESubTile{row, col, smeTileType};
164int getNumberOfSMETilesForVectorType(VectorType type) {
166 "`type` not multiple of SME tiles");
167 int64_t vectorRows = type.getDimSize(0);
168 int64_t vectorCols = type.getDimSize(1);
169 auto elementType = type.getElementType();
171 return (vectorRows * vectorCols) / (minNumElts * minNumElts);
176struct LegalizeArithConstantOpsByDecomposition
177 :
public OpConversionPattern<arith::ConstantOp> {
178 using OpConversionPattern::OpConversionPattern;
181 matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
182 ConversionPatternRewriter &rewriter)
const override {
183 auto vectorType = dyn_cast<VectorType>(constantOp.getType());
184 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
185 if (!vectorType || !denseAttr || !denseAttr.isSplat())
189 return rewriter.notifyMatchFailure(constantOp,
190 kMatchFailureNotSMETileTypeMultiple);
193 auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
194 auto tileSplat = arith::ConstantOp::create(
195 rewriter, constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
196 SmallVector<Value> repl(tileCount, tileSplat);
197 rewriter.replaceOpWithMultiple(constantOp, {repl});
205struct LegalizeVectorOuterProductOpsByDecomposition
206 :
public OpConversionPattern<vector::OuterProductOp> {
207 using OpConversionPattern::OpConversionPattern;
210 matchAndRewrite(vector::OuterProductOp outerProductOp,
211 OneToNOpAdaptor adaptor,
212 ConversionPatternRewriter &rewriter)
const override {
213 auto vectorType = outerProductOp.getResultVectorType();
215 return rewriter.notifyMatchFailure(outerProductOp,
216 kMatchFailureNotSMETileTypeMultiple);
219 Operation *rootOp = outerProductOp;
220 auto loc = outerProductOp.getLoc();
221 if (outerProductOp.isMasked()) {
222 auto maskOp = outerProductOp.getMaskingOp();
223 mask = maskOp.getMask();
225 rewriter.setInsertionPoint(rootOp);
228 if (!isSupportedMaskOp(mask))
229 return rewriter.notifyMatchFailure(outerProductOp,
230 kMatchFailureUnsupportedMaskOp);
234 VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);
236 SmallVector<Value> resultSMETiles;
237 for (
auto [index, smeTile] : llvm::enumerate(
238 decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
240 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
241 auto lhs = vector::ScalableExtractOp::create(
242 rewriter, loc, sliceType, outerProductOp.getLhs(), smeTile.row);
243 auto rhs = vector::ScalableExtractOp::create(
244 rewriter, loc, sliceType, outerProductOp.getRhs(), smeTile.col);
245 auto smeOuterProduct = vector::OuterProductOp::create(
246 rewriter, loc, smeTileType,
lhs,
rhs,
247 !accSMETiles.empty() ? accSMETiles[index] : Value{},
248 outerProductOp.getKind());
250 auto *maskedOuterProduct =
252 resultSMETiles.push_back(maskedOuterProduct->getResult(0));
255 rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
265struct LegalizeMaskedVectorOuterProductOpsByDecomposition
266 :
public OpConversionPattern<vector::MaskOp> {
267 using OpConversionPattern::OpConversionPattern;
270 matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
271 ConversionPatternRewriter &rewriter)
const override {
272 if (
auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
273 maskOp.getMaskableOp())) {
274 LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
276 return static_cast<RewritePattern &
>(pattern).matchAndRewrite(
277 outerProductOp, rewriter);
285struct LegalizeTransferReadOpsByDecomposition
286 :
public OpConversionPattern<vector::TransferReadOp> {
287 using OpConversionPattern::OpConversionPattern;
290 matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
291 ConversionPatternRewriter &rewriter)
const override {
292 auto vectorType = readOp.getVectorType();
294 return rewriter.notifyMatchFailure(readOp,
295 kMatchFailureNotSMETileTypeMultiple);
297 auto mask = readOp.getMask();
298 if (!isSupportedMaskOp(mask))
299 return rewriter.notifyMatchFailure(readOp,
300 kMatchFailureUnsupportedMaskOp);
302 auto permutationMap = readOp.getPermutationMap();
303 if (!permutationMap.isPermutation())
304 return rewriter.notifyMatchFailure(readOp,
305 kMatchFailureNonPermutationMap);
309 bool transposed = !permutationMap.isIdentity();
311 auto loc = readOp.getLoc();
315 for (SMESubTile smeTile :
316 decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
317 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
318 auto smeRead = vector::TransferReadOp::create(
319 rewriter, loc, smeTileType, readOp.getBase(),
320 getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
321 readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
322 readOp.getInBoundsAttr());
323 resultSMETiles.push_back(smeRead);
326 rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
333struct LegalizeTransferWriteOpsByDecomposition
334 :
public OpConversionPattern<vector::TransferWriteOp> {
335 using OpConversionPattern::OpConversionPattern;
339 ConversionPatternRewriter &rewriter)
const override {
340 auto vectorType = writeOp.getVectorType();
342 return rewriter.notifyMatchFailure(writeOp,
343 kMatchFailureNotSMETileTypeMultiple);
345 auto mask = writeOp.getMask();
346 if (!isSupportedMaskOp(mask))
347 return rewriter.notifyMatchFailure(writeOp,
348 kMatchFailureUnsupportedMaskOp);
350 auto permutationMap = writeOp.getPermutationMap();
351 if (!permutationMap.isPermutation())
352 return rewriter.notifyMatchFailure(writeOp,
353 kMatchFailureNonPermutationMap);
357 bool transposed = !permutationMap.isIdentity();
359 auto loc = writeOp.getLoc();
361 auto inputSMETiles = adaptor.getValueToStore();
363 Value destTensorOrMemref = writeOp.getBase();
364 for (
auto [
index, smeTile] : llvm::enumerate(decomposeToSMETiles(
365 rewriter, vectorType, smeTileType, transposed))) {
366 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
367 auto smeWrite = vector::TransferWriteOp::create(
368 rewriter, loc, inputSMETiles[
index], destTensorOrMemref,
369 getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
370 writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
371 if (writeOp.hasPureTensorSemantics())
372 destTensorOrMemref = smeWrite.getResult();
375 if (writeOp.hasPureTensorSemantics())
376 rewriter.replaceOp(writeOp, destTensorOrMemref);
378 rewriter.eraseOp(writeOp);
415struct LegalizeMultiTileTransferWriteAsStoreLoop
416 :
public OpConversionPattern<vector::TransferWriteOp> {
417 using OpConversionPattern::OpConversionPattern;
420 matchAndRewrite(vector::TransferWriteOp writeOp,
OneToNOpAdaptor adaptor,
421 ConversionPatternRewriter &rewriter)
const override {
422 if (writeOp.hasPureTensorSemantics())
423 return rewriter.notifyMatchFailure(
424 writeOp,
"TODO: tensor semantics are unsupported");
426 auto permutationMap = writeOp.getPermutationMap();
427 if (!permutationMap.isPermutation())
428 return rewriter.notifyMatchFailure(writeOp,
429 kMatchFailureNonPermutationMap);
431 bool transposed = !permutationMap.isIdentity();
433 return rewriter.notifyMatchFailure(writeOp,
434 "TODO: transpose unsupported");
436 auto vectorType = writeOp.getVectorType();
438 return rewriter.notifyMatchFailure(writeOp,
439 kMatchFailureNotSMETileTypeMultiple);
443 auto mask = writeOp.getMask();
444 if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
445 vectorType.getDimSize(1) > 16)))
446 return rewriter.notifyMatchFailure(writeOp,
447 kMatchFailureUnsupportedMaskOp);
449 auto loc = writeOp.getLoc();
450 auto createVscaleMultiple =
455 auto minTileSlices = smeTileType.getDimSize(0);
456 VectorType sliceMaskType =
457 VectorType::get(minTileSlices, rewriter.getI1Type(),
true);
461 auto upperBound = createVscaleMultiple(minTileSlices);
464 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
465 rewriter.setInsertionPointToStart(storeLoop.getBody());
468 auto inputSMETiles = adaptor.getValueToStore();
469 auto tileSliceIndex = storeLoop.getInductionVar();
470 for (
auto [
index, smeTile] : llvm::enumerate(
471 decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
473 auto tileRow = createVscaleMultiple(smeTile.row);
474 auto tileCol = createVscaleMultiple(smeTile.col);
478 arith::AddIOp::create(rewriter, loc, tileRow, tileSliceIndex);
481 auto storeRow = arith::AddIOp::create(rewriter, loc, sliceIndex,
482 writeOp.getIndices()[0]);
483 auto storeCol = arith::AddIOp::create(rewriter, loc, tileCol,
484 writeOp.getIndices()[1]);
487 Value sliceMask =
nullptr;
489 sliceMask = vector::ExtractOp::create(rewriter, loc, mask,
491 if (sliceMaskType != sliceMask.
getType())
492 sliceMask = vector::ScalableExtractOp::create(
493 rewriter, loc, sliceMaskType, sliceMask, smeTile.col);
499 vector::ExtractOp::create(rewriter, loc,
tile, tileSliceIndex);
500 vector::TransferWriteOp::create(
501 rewriter, loc, slice, writeOp.getBase(),
503 AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
505 rewriter.getBoolArrayAttr(
509 rewriter.eraseOp(writeOp);
537struct FoldExtractFromVectorOfSMELikeCreateMasks
539 using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
541 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
542 PatternRewriter &rewriter)
const override {
543 auto loc = extractOp.getLoc();
545 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
548 extractOp,
"extract not from vector.create_mask op");
550 VectorType extractedMaskType =
551 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
552 if (!extractedMaskType)
554 "extracted type is not a vector type");
556 auto numScalable = extractedMaskType.getNumScalableDims();
557 if (numScalable != 2)
559 extractOp,
"expected extracted type to be an SME-like mask");
562 if (extractOp.getStaticPosition().size() != 1)
564 extractOp,
"only a single extraction index is supported");
566 auto frontMaskDim = createMaskOp.getOperand(0);
567 if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
570 "constant vector.create_masks dims should be folded elsewhere");
574 rewriter, loc, extractOp.getMixedPosition()[0]);
575 auto extractionInTrueRegion = arith::CmpIOp::create(
576 rewriter, loc, rewriter.
getI1Type(), arith::CmpIPredicate::slt,
577 extractionIndex, frontMaskDim);
578 auto newMaskFrontDim =
579 arith::SelectOp::create(rewriter, loc, extractionInTrueRegion,
580 createMaskOp.getOperand(1), zero);
583 extractOp, extractedMaskType,
584 ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
590bool isLegalVectorType(VectorType vType) {
591 bool seenFixedDim =
false;
592 for (
bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
593 seenFixedDim |= !scalableFlag;
594 if (seenFixedDim && scalableFlag)
628struct LiftIllegalVectorTransposeToMemory
630 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
632 static Value getExtensionSource(Operation *op) {
633 if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
638 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
639 PatternRewriter &rewriter)
const override {
640 auto sourceType = transposeOp.getSourceVectorType();
641 auto resultType = transposeOp.getResultVectorType();
642 if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
644 kMatchFailureNotIllegalToLegal);
647 Value maybeRead = transposeOp.getVector();
649 Operation *extendOp =
nullptr;
650 if (Value extendSource = getExtensionSource(transposeSourceOp)) {
651 maybeRead = extendSource;
652 extendOp = transposeSourceOp;
655 auto illegalRead = maybeRead.
getDefiningOp<vector::TransferReadOp>();
659 "expected source to be (possibly extended) transfer_read");
661 if (!illegalRead.getPermutationMap().isIdentity())
663 illegalRead,
"expected read to have identity permutation map");
665 auto loc = transposeOp.getLoc();
670 auto readType = illegalRead.getVectorType();
671 auto readSizes = llvm::map_to_vector(
672 llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
673 [&](
auto dim) -> Value {
674 auto [size, isScalable] = dim;
675 auto dimSize = arith::ConstantIndexOp::create(rewriter, loc, size);
678 auto vscale = vector::VectorScaleOp::create(rewriter, loc);
679 return arith::MulIOp::create(rewriter, loc, vscale, dimSize);
681 SmallVector<Value> strides(readType.getRank(), Value(one));
683 memref::SubViewOp::create(rewriter, loc, illegalRead.getBase(),
684 illegalRead.getIndices(), readSizes, strides);
688 Value mask = illegalRead.getMask();
692 mask = vector::TransposeOp::create(rewriter, loc, mask,
693 transposeOp.getPermutation());
698 auto transposedSubview = memref::TransposeOp::create(
699 rewriter, loc, readSubview, AffineMapAttr::get(transposeMap));
700 ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
703 SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
709 VectorType legalReadType = resultType.clone(readType.getElementType());
711 SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
712 auto legalRead = vector::TransferReadOp::create(
713 rewriter, loc, legalReadType, transposedSubview, readIndices,
714 illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
719 rewriter.
replaceOp(transposeOp, [&]() -> Operation * {
722 Value(legalRead), resultType);
759struct LowerIllegalTransposeStoreViaZA
763 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
764 PatternRewriter &rewriter)
const override {
765 if (!isSupportedMaskOp(writeOp.getMask()))
767 kMatchFailureUnsupportedMaskOp);
769 auto permutationMap = writeOp.getPermutationMap();
770 if (!permutationMap.isIdentity())
772 kMatchFailureNonPermutationMap);
774 auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
778 auto sourceType = transposeOp.getSourceVectorType();
779 auto resultType = transposeOp.getResultVectorType();
781 if (resultType.getRank() != 2)
784 if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
786 transposeOp,
"not illegal/unsupported SVE transpose");
789 VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);
791 if (sourceType.getDimSize(0) <= 1 ||
792 sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
795 auto loc = writeOp.getLoc();
796 auto createVscaleMultiple =
799 auto transposeMap = AffineMapAttr::get(
803 Value undefTile = arm_sme::GetTileOp::create(rewriter, loc, smeTileType);
804 Value destTensorOrMemref = writeOp.getBase();
805 auto numSlicesPerTile =
806 std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
809 for (
auto [index, smeTile] : llvm::enumerate(
810 decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
815 Value
tile = undefTile;
816 for (
int d = 0; d < numSlicesPerTile; ++d) {
818 vector::ExtractOp::create(rewriter, loc, transposeOp.getVector(),
820 if (vector.
getType() != smeSliceType) {
821 vector = vector::ScalableExtractOp::create(
822 rewriter, loc, smeSliceType, vector, smeTile.col);
824 tile = vector::InsertOp::create(rewriter, loc, vector,
tile, d);
828 auto transposedRow = createVscaleMultiple(smeTile.col);
835 if (
auto mask = writeOp.getMask()) {
837 maskRows = arith::SubIOp::create(
838 rewriter, loc,
createMask.getOperand(0), transposedRow);
839 maskCols = arith::SubIOp::create(
840 rewriter, loc,
createMask.getOperand(1), transposedCol);
841 maskCols = index::MinSOp::create(rewriter, loc, maskCols, numSlices);
843 maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
844 maskCols = numSlices;
846 auto subMask = vector::CreateMaskOp::create(
847 rewriter, loc, smeTileType.clone(rewriter.
getI1Type()),
851 auto writeIndices = writeOp.getIndices();
853 arith::AddIOp::create(rewriter, loc, transposedRow, writeIndices[0]);
855 arith::AddIOp::create(rewriter, loc, transposedCol, writeIndices[1]);
856 auto smeWrite = vector::TransferWriteOp::create(
857 rewriter, loc,
tile, destTensorOrMemref,
ValueRange{destRow, destCol},
858 transposeMap, subMask, writeOp.getInBounds());
860 if (writeOp.hasPureTensorSemantics())
861 destTensorOrMemref = smeWrite.getResult();
864 if (writeOp.hasPureTensorSemantics())
865 rewriter.
replaceOp(writeOp, destTensorOrMemref);
904struct LowerColumnTransferReadToLoops
908 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
909 PatternRewriter &rewriter)
const override {
912 if (readOp.hasPureTensorSemantics())
914 readOp,
"Tensor semantics are unsupported (either bufferize or "
915 "extend this pattern)");
917 auto resType = readOp.getVectorType();
919 if (resType.getRank() != 2)
921 "Only 2D vectors are supported!");
923 if (resType.getShape()[1] != 1)
925 readOp,
"The trailing output dim is != 1 (not supported ATM)");
927 if (!resType.getScalableDims()[0] || resType.getScalableDims()[1])
929 readOp,
"Expected the leading dim to be scalable and the trailing "
934 int64_t numRows = resType.getShape()[0];
935 VectorType newResType = VectorType::get(numRows, resType.getElementType(),
939 auto loc = readOp.getLoc();
941 auto createVscaleMultiple =
943 auto upperBound = createVscaleMultiple(numRows);
945 Value init = arith::ConstantOp::create(
950 OpBuilder::InsertionGuard g(rewriter);
951 loadLoop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
955 auto tileSliceIndex = loadLoop.getInductionVar();
957 auto idx0 = arith::AddIOp::create(rewriter, loc, tileSliceIndex,
958 readOp.getIndices()[0]);
959 auto idx1 = readOp.getIndices()[1];
961 Value scalar = memref::LoadOp::create(rewriter, loc, readOp.getBase(),
962 SmallVector<Value>({idx0, idx1}));
964 Operation *updateInit = vector::InsertOp::create(
965 rewriter, loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex);
967 scf::YieldOp::create(rewriter, loc, updateInit->
getResult(0));
974 auto sc = vector::ShapeCastOp::create(
975 rewriter, loc, readOp.getResult().getType(), loadLoop.getResult(0));
983struct VectorLegalizationPass
985 void runOnOperation()
override {
987 TypeConverter converter;
988 RewritePatternSet
patterns(context);
989 converter.addConversion([](Type type) {
return type; });
990 converter.addConversion(
991 [](VectorType vectorType,
992 SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
995 auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
998 types = SmallVector<Type>(smeTileCount, smeTileType);
1003 RewritePatternSet rewritePatterns(context);
1005 .add<FoldExtractFromVectorOfSMELikeCreateMasks,
1006 LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory,
1007 LowerIllegalTransposeStoreViaZA>(context);
1010 return signalPassFailure();
1015 patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
1016 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
1018 patterns.add<LegalizeArithConstantOpsByDecomposition,
1019 LegalizeVectorOuterProductOpsByDecomposition,
1020 LegalizeTransferReadOpsByDecomposition,
1021 LegalizeTransferWriteOpsByDecomposition>(converter, context);
1022 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
1029 target.markUnknownOpDynamicallyLegal(
1030 [&](Operation *op) {
return converter.isLegal(op); });
1031 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1032 return converter.isSignatureLegal(op.getFunctionType());
1034 if (
failed(applyPartialConversion(getOperation(),
target,
1036 return signalPassFailure();
1043 return std::make_unique<VectorLegalizationPass>();
static Value createMask(AffineForOp vecForOp, VectorizationState &state)
Creates a mask used to filter out garbage elements in the last iteration of unaligned loops.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
IntegerAttr getIndexAttr(int64_t value)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
OperationName getName()
The name of an operation is the key identifier for it.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
std::unique_ptr< Pass > createVectorLegalizationPass()
Pass that legalizes vectors so they can be lowered to ArmSME.
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc)
Returns a functor (int64_t -> Value) which returns a constant vscale multiple.
Include the generated interface declarations.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...