26 #define DEBUG_TYPE "arm-sme-vector-legalization"
29 #define GEN_PASS_DEF_VECTORLEGALIZATION
30 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
// Reasons reported when one of the legalization patterns in this file
// declines to match. Each constant is passed as the message argument of a
// match-failure notification (NOTE(review): the call itself is not visible in
// this excerpt; presumably rewriter.notifyMatchFailure(op, reason) — confirm).

// Reported by the decomposition patterns when the op's vector type is not a
// whole multiple of the SME tile size (the checking condition is not visible
// in this excerpt).
43 static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
44 "op vector size is not multiple of SME tiles");
// Reported when the op's mask fails isSupportedMaskOp(), so it cannot be
// split into per-SME-tile masks.
45 static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
46 "op mask is unsupported for legalization/decomposition");
// Reported when a transfer op's permutation map is not a permutation
// (AffineMap::isPermutation() returns false).
47 static constexpr StringLiteral
48 kMatchFailureNonPermutationMap(
"op affine map is not a permutation");
// Reported when an op does not go from an illegal vector type to a legal one,
// i.e. its source type is already legal or its result type is not legal
// according to isLegalVectorType(). NOTE: despite the wording, this is used
// by both the vector.transpose and the vector.shape_cast patterns.
49 static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
50 "expected transpose from illegal type to legal type");
80 auto vscale = builder.
create<vector::VectorScaleOp>(loc);
81 return llvm::map_to_vector(
82 llvm::zip_equal(indices, scalableOffsets), [&](
auto pair) ->
Value {
83 auto [index, base] = pair;
84 auto offset = builder.
create<arith::MulIOp>(
85 loc, builder.
create<arith::ConstantIndexOp>(loc, base), vscale);
86 return builder.
create<arith::AddIOp>(loc, index, offset);
106 SMESubTile smeTile) {
107 return addConstantScalableOffset(builder, loc, indices,
108 {smeTile.row, smeTile.col});
114 bool isSupportedMaskOp(
Value mask) {
120 SMESubTile smeTile) {
121 assert(isSupportedMaskOp(mask));
128 auto smeTileMaskDims = addConstantScalableOffset(
129 builder, loc,
createMask.getOperands(), {-smeTile.row, -smeTile.col});
130 auto smeTileCreateMask = builder.
create<vector::CreateMaskOp>(
131 loc, smeTile.type.clone(builder.
getI1Type()), smeTileMaskDims);
139 auto decomposeToSMETiles(
OpBuilder &builder, VectorType type,
140 VectorType smeTileType,
141 bool transposeIndices =
false) {
143 "`type` not multiple of SME tiles");
144 return llvm::map_range(
146 smeTileType.getDimSize(1)}),
148 int row = int(indices[0]);
149 int col = int(indices[1]);
150 if (transposeIndices)
152 return SMESubTile{row, col, smeTileType};
158 int getNumberOfSMETilesForVectorType(VectorType type) {
160 "`type` not multiple of SME tiles");
161 int64_t vectorRows = type.getDimSize(0);
162 int64_t vectorCols = type.getDimSize(1);
163 auto elementType = type.getElementType();
165 return (vectorRows * vectorCols) / (minNumElts * minNumElts);
170 struct LegalizeArithConstantOpsByDecomposition
175 matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
177 auto vectorType = dyn_cast<VectorType>(constantOp.getType());
178 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
179 if (!vectorType || !denseAttr || !denseAttr.isSplat())
184 kMatchFailureNotSMETileTypeMultiple);
187 auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
188 auto tileSplat = rewriter.
create<arith::ConstantOp>(
189 constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
191 adaptor.getResultMapping());
199 struct LegalizeVectorOuterProductOpsByDecomposition
204 matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
206 auto vectorType = outerProductOp.getResultVectorType();
209 kMatchFailureNotSMETileTypeMultiple);
213 auto loc = outerProductOp.
getLoc();
214 if (outerProductOp.isMasked()) {
215 auto maskOp = outerProductOp.getMaskingOp();
216 mask = maskOp.getMask();
220 if (!isSupportedMaskOp(mask))
222 kMatchFailureUnsupportedMaskOp);
230 decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
232 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
233 auto lhs = rewriter.
create<vector::ScalableExtractOp>(
234 loc, sliceType, outerProductOp.getLhs(), smeTile.row);
235 auto rhs = rewriter.
create<vector::ScalableExtractOp>(
236 loc, sliceType, outerProductOp.getRhs(), smeTile.col);
237 auto smeOuterProduct = rewriter.
create<vector::OuterProductOp>(
238 loc, smeTileType, lhs, rhs,
239 !accSMETiles.empty() ? accSMETiles[index] :
Value{},
240 outerProductOp.getKind());
242 auto maskedOuterProduct =
244 resultSMETiles.push_back(maskedOuterProduct->getResult(0));
247 rewriter.
replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
257 struct LegalizeMaskedVectorOuterProductOpsByDecomposition
262 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
264 if (
auto outerProductOp =
265 llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
266 LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
269 outerProductOp, rewriter);
277 struct LegalizeTransferReadOpsByDecomposition
282 matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
284 auto vectorType = readOp.getVectorType();
287 kMatchFailureNotSMETileTypeMultiple);
289 auto mask = readOp.getMask();
290 if (!isSupportedMaskOp(mask))
292 kMatchFailureUnsupportedMaskOp);
294 auto permutationMap = readOp.getPermutationMap();
295 if (!permutationMap.isPermutation())
297 kMatchFailureNonPermutationMap);
301 bool transposed = !permutationMap.isIdentity();
303 auto loc = readOp.getLoc();
307 for (SMESubTile smeTile :
308 decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
309 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
310 auto smeRead = rewriter.
create<vector::TransferReadOp>(
311 loc, smeTileType, readOp.getSource(),
312 getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
313 readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
314 readOp.getInBoundsAttr());
315 resultSMETiles.push_back(smeRead);
318 rewriter.
replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
325 struct LegalizeTransferWriteOpsByDecomposition
330 matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
332 auto vectorType = writeOp.getVectorType();
335 kMatchFailureNotSMETileTypeMultiple);
337 auto mask = writeOp.getMask();
338 if (!isSupportedMaskOp(mask))
340 kMatchFailureUnsupportedMaskOp);
342 auto permutationMap = writeOp.getPermutationMap();
343 if (!permutationMap.isPermutation())
345 kMatchFailureNonPermutationMap);
349 bool transposed = !permutationMap.isIdentity();
351 auto loc = writeOp.getLoc();
353 auto inputSMETiles = adaptor.getVector();
355 Value destTensorOrMemref = writeOp.getSource();
357 rewriter, vectorType, smeTileType, transposed))) {
358 auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
359 auto smeWrite = rewriter.
create<vector::TransferWriteOp>(
360 loc, inputSMETiles[index], destTensorOrMemref,
361 getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
362 writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
363 if (writeOp.hasPureTensorSemantics())
364 destTensorOrMemref = smeWrite.
getResult();
367 if (writeOp.hasPureTensorSemantics())
368 rewriter.
replaceOp(writeOp, destTensorOrMemref);
399 struct FoldExtractFromVectorOfSMELikeCreateMasks
405 auto loc = extractOp.getLoc();
407 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
410 extractOp,
"extract not from vector.create_mask op");
412 VectorType extractedMaskType =
413 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
414 if (!extractedMaskType)
416 "extracted type is not a vector type");
418 auto numScalable = llvm::count(extractedMaskType.getScalableDims(),
true);
419 if (numScalable != 2)
421 extractOp,
"expected extracted type to be an SME-like mask");
424 if (extractOp.getStaticPosition().size() != 1)
426 extractOp,
"only a single extraction index is supported");
428 auto frontMaskDim = createMaskOp.getOperand(0);
429 if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
432 "constant vector.create_masks dims should be folded elsewhere");
434 auto zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
436 rewriter, loc, extractOp.getMixedPosition()[0]);
437 auto extractionInTrueRegion = rewriter.
create<arith::CmpIOp>(
438 loc, rewriter.
getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
440 auto newMaskFrontDim = rewriter.
create<arith::SelectOp>(
441 loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);
444 extractOp, extractedMaskType,
445 ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
451 bool isLegalVectorType(VectorType vType) {
452 bool seenFixedDim =
false;
453 for (
bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
454 seenFixedDim |= !scalableFlag;
455 if (seenFixedDim && scalableFlag)
489 struct LiftIllegalVectorTransposeToMemory
494 if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
499 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
501 auto sourceType = transposeOp.getSourceVectorType();
502 auto resultType = transposeOp.getResultVectorType();
503 if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
505 kMatchFailureNotIllegalToLegal);
508 Value maybeRead = transposeOp.getVector();
511 if (
Value extendSource = getExtensionSource(transposeSourceOp)) {
512 maybeRead = extendSource;
513 extendOp = transposeSourceOp;
516 auto illegalRead = maybeRead.
getDefiningOp<vector::TransferReadOp>();
520 "expected source to be (possibly extended) transfer_read");
522 if (!illegalRead.getPermutationMap().isIdentity())
524 illegalRead,
"expected read to have identity permutation map");
526 auto loc = transposeOp.getLoc();
527 auto zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
528 auto one = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
531 auto readType = illegalRead.getVectorType();
532 auto readSizes = llvm::map_to_vector(
533 llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
534 [&](
auto dim) ->
Value {
535 auto [size, isScalable] = dim;
536 auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
539 auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
540 return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
543 auto readSubview = rewriter.
create<memref::SubViewOp>(
544 loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
549 Value mask = illegalRead.getMask();
553 mask = rewriter.
create<vector::TransposeOp>(loc, mask,
554 transposeOp.getPermutation());
559 auto transposedSubview = rewriter.
create<memref::TransposeOp>(
561 ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
570 VectorType legalReadType = resultType.clone(readType.getElementType());
573 auto legalRead = rewriter.
create<vector::TransferReadOp>(
574 loc, legalReadType, transposedSubview, readIndices,
575 illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
583 Value(legalRead), resultType);
611 struct ConvertIllegalShapeCastOpsToTransposes
615 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
617 auto sourceType = shapeCastOp.getSourceVectorType();
618 auto resultType = shapeCastOp.getResultVectorType();
619 if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
621 kMatchFailureNotIllegalToLegal);
625 if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
627 shapeCastOp,
"expected source to be a 2D scalable vector with a "
628 "trailing unit dim");
630 auto loc = shapeCastOp.getLoc();
634 if (resultType.getRank() == 1)
644 struct VectorLegalizationPass
645 :
public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
646 void runOnOperation()
override {
652 [](VectorType vectorType,
656 auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
663 patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
664 LiftIllegalVectorTransposeToMemory,
665 ConvertIllegalShapeCastOpsToTransposes>(context);
667 patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
668 converter, context, 1024);
669 patterns.add<LegalizeArithConstantOpsByDecomposition,
670 LegalizeVectorOuterProductOpsByDecomposition,
671 LegalizeTransferReadOpsByDecomposition,
672 LegalizeTransferWriteOpsByDecomposition>(converter, context);
677 std::move(patterns))))
678 return signalPassFailure();
685 return std::make_unique<VectorLegalizationPass>();
static MLIRContext * getContext(OpFoldResult val)
static Value createMask(AffineForOp vecForOp, VectorizationState &state)
Creates a mask used to filter out garbage elements in the last iteration of unaligned loops.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class is a wrapper around OneToNConversionPattern for matching against instances of a particular...
OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Specialization of PatternRewriter that OneToNConversionPatterns use.
void replaceOp(Operation *op, ValueRange newValues, const OneToNTypeMapping &resultMapping)
Replaces the results of the operation with the specified list of values mapped back to the original t...
Extends TypeConverter with 1:N target materializations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperationName getName()
The name of an operation is the key identifier for it.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
std::unique_ptr< Pass > createVectorLegalizationPass()
Pass that legalizes vectors so they can be lowered to ArmSME.
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the provided pattern set with patterns that do 1:N type conversions on (some) SCF ops.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...