16 #include "llvm/Support/Casting.h"
45 struct TransferReadToArmSMELowering
49 LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
52 if (transferReadOp.getTransferRank() != 2)
53 return rewriter.notifyMatchFailure(transferReadOp,
54 "not a 2 result permutation map");
56 auto vectorType = transferReadOp.getVectorType();
58 return rewriter.notifyMatchFailure(transferReadOp,
59 "not a valid vector type for SME");
61 if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
62 return rewriter.notifyMatchFailure(transferReadOp,
"not a memref source");
65 if (transferReadOp.hasOutOfBoundsDim())
66 return rewriter.notifyMatchFailure(transferReadOp,
67 "not inbounds transfer read");
71 return rewriter.notifyMatchFailure(transferReadOp,
72 "unsupported permutation map");
77 arm_sme::TileSliceLayout layout =
78 transposed ? arm_sme::TileSliceLayout::Vertical
79 : arm_sme::TileSliceLayout::Horizontal;
84 auto mask = transferReadOp.getMask();
85 auto padding = mask ? transferReadOp.getPadding() :
nullptr;
86 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
87 transferReadOp, vectorType, transferReadOp.getSource(),
88 transferReadOp.getIndices(), padding, mask, layout);
121 struct TransferWriteToArmSMELowering
125 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
127 auto vType = writeOp.getVectorType();
131 if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
135 if (writeOp.hasOutOfBoundsDim())
136 return rewriter.notifyMatchFailure(writeOp,
137 "not inbounds transfer write");
141 return rewriter.notifyMatchFailure(writeOp,
142 "unsupported permutation map");
147 arm_sme::TileSliceLayout layout =
148 transposed ? arm_sme::TileSliceLayout::Vertical
149 : arm_sme::TileSliceLayout::Horizontal;
151 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
152 writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
153 writeOp.getMask(), layout);
159 struct VectorLoadToArmSMELowering :
public OpRewritePattern<vector::LoadOp> {
162 LogicalResult matchAndRewrite(vector::LoadOp load,
168 load, load.getVectorType(), load.getBase(), load.getIndices());
175 struct VectorStoreToArmSMELowering :
public OpRewritePattern<vector::StoreOp> {
178 LogicalResult matchAndRewrite(vector::StoreOp store,
184 store, store.getValueToStore(), store.getBase(), store.getIndices());
209 struct BroadcastOpToArmSMELowering
213 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
215 auto tileType = broadcastOp.getResultVectorType();
219 auto loc = broadcastOp.getLoc();
221 auto srcType = broadcastOp.getSourceType();
222 auto srcVectorType = dyn_cast<VectorType>(srcType);
225 if (srcType.isIntOrFloat() ||
226 (srcVectorType && (srcVectorType.getRank() == 0))) {
229 broadcastOp1D = rewriter.
create<vector::BroadcastOp>(
230 loc, tileSliceType, broadcastOp.getSource());
231 }
else if (srcVectorType && (srcVectorType.getRank() == 1))
233 broadcastOp1D = broadcastOp.getSource();
237 auto initTile = rewriter.
create<arm_sme::GetTileOp>(loc, tileType);
243 auto nextTile = b.
create<arm_sme::InsertTileSliceOp>(
244 loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
252 rewriter.
replaceOp(broadcastOp, forOp.getResult(0));
280 LogicalResult matchAndRewrite(vector::SplatOp splatOp,
282 auto tileType = splatOp.getResult().getType();
286 auto loc = splatOp.getLoc();
287 auto srcType = splatOp.getOperand().getType();
289 assert(srcType.isIntOrFloat() &&
"Invalid source type for vector.splat");
295 Value broadcastOp1D = rewriter.
create<vector::BroadcastOp>(
296 loc, tileSliceType, splatOp.getInput());
298 auto initTile = rewriter.
create<arm_sme::GetTileOp>(loc, tileType);
302 auto nextTile = b.
create<arm_sme::InsertTileSliceOp>(
303 loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
312 rewriter.
replaceOp(splatOp, forOp.getResult(0));
340 struct TransposeOpToArmSMELowering
344 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
346 auto tileType = transposeOp.getResultVectorType();
352 if (permutation[0] != 1 || permutation[1] != 0)
355 auto loc = transposeOp.getLoc();
356 Value input = transposeOp.getVector();
358 if (
auto xferOp = input.
getDefiningOp<vector::TransferReadOp>();
363 xferOp->setAttr(xferOp.getPermutationMapAttrName(),
365 permutation, transposeOp.getContext())));
374 Value minTileSlices = rewriter.
create<arith::ConstantOp>(
378 Value numTileSlices =
379 rewriter.
create<arith::MulIOp>(loc, vscale, minTileSlices);
382 tileType.getElementType());
383 auto buffer = rewriter.
create<memref::AllocaOp>(
384 loc, bufferType,
ValueRange{numTileSlices, numTileSlices});
387 auto tileStoreOp = rewriter.
create<arm_sme::TileStoreOp>(
392 transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
393 arm_sme::TileSliceLayout::Vertical);
432 struct VectorOuterProductToArmSMELowering
437 LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
442 if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
444 "AXPY operations not supported");
447 outerProductOp.getResultVectorType()))
449 outerProductOp,
"outer product does not fit into SME tile");
451 auto kind = outerProductOp.getKind();
452 if (kind != vector::CombiningKind::ADD)
455 "unsupported kind (lowering to SME only supports ADD at the moment)");
460 auto loc = outerProductOp.
getLoc();
461 if (outerProductOp.isMasked()) {
462 auto maskOp = outerProductOp.getMaskingOp();
465 auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
466 if (failed(operandMasks))
468 std::tie(lhsMask, rhsMask) = *operandMasks;
472 rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
473 outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
478 static FailureOr<std::pair<Value, Value>>
482 auto createMaskOp = mask.
getDefiningOp<vector::CreateMaskOp>();
486 auto maskType = createMaskOp.getVectorType();
487 Value lhsMaskDim = createMaskOp.getOperand(0);
488 Value rhsMaskDim = createMaskOp.getOperand(1);
492 rewriter.
create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
494 rewriter.
create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
496 return std::make_pair(lhsMask, rhsMask);
512 struct VectorExtractToArmSMELowering
516 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
518 VectorType sourceType = extractOp.getSourceVectorType();
522 auto loc = extractOp.getLoc();
523 auto position = extractOp.getMixedPosition();
525 Value sourceVector = extractOp.getVector();
528 if (position.empty()) {
529 rewriter.
replaceOp(extractOp, sourceVector);
534 auto extractTileSlice = rewriter.
create<arm_sme::ExtractTileSliceOp>(
535 loc, sourceVector, sliceIndex);
537 if (position.size() == 1) {
539 rewriter.
replaceOp(extractOp, extractTileSlice);
544 assert(position.size() == 2);
568 struct VectorInsertToArmSMELowering
572 LogicalResult matchAndRewrite(vector::InsertOp insertOp,
574 VectorType resultType = insertOp.getResult().getType();
579 auto loc = insertOp.getLoc();
580 auto position = insertOp.getMixedPosition();
582 Value source = insertOp.getSource();
586 if (position.empty()) {
591 Value tileSlice = source;
593 if (position.size() == 2) {
596 tileSlice = rewriter.
create<arm_sme::ExtractTileSliceOp>(
597 loc, insertOp.getDest(), sliceIndex);
598 tileSlice = rewriter.
create<vector::InsertOp>(loc, source, tileSlice,
604 insertOp, tileSlice, insertOp.getDest(), sliceIndex);
630 struct VectorPrintToArmSMELowering :
public OpRewritePattern<vector::PrintOp> {
633 LogicalResult matchAndRewrite(vector::PrintOp
printOp,
638 VectorType vectorType = dyn_cast<VectorType>(
printOp.getPrintType());
645 auto vscale = rewriter.
create<vector::VectorScaleOp>(loc);
647 rewriter.
create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
648 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
649 auto upperBound = rewriter.
create<arith::MulIOp>(loc, minTileRows, vscale);
650 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
651 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step);
656 Value rowIndex = forOp.getInductionVar();
657 auto tileSlice = rewriter.
create<arm_sme::ExtractTileSliceOp>(
658 loc,
printOp.getSource(), rowIndex);
660 rewriter.
create<vector::PrintOp>(loc, tileSlice,
683 struct FoldTransferWriteOfExtractTileSlice
687 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
689 if (!isa<MemRefType>(writeOp.getSource().getType()))
692 if (writeOp.hasOutOfBoundsDim())
694 "not inbounds transfer write");
696 auto extractTileSlice =
697 writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>();
698 if (!extractTileSlice)
700 writeOp,
"vector to store not from ExtractTileSliceOp");
705 "unsupported permutation map");
707 Value mask = writeOp.getMask();
709 auto maskType = writeOp.getVectorType().clone(rewriter.
getI1Type());
710 mask = rewriter.
create<arith::ConstantOp>(
715 writeOp, extractTileSlice.getTile(),
716 extractTileSlice.getTileSliceIndex(), mask, writeOp.getSource(),
717 writeOp.getIndices(), extractTileSlice.getLayout());
739 struct ExtractFromCreateMaskToPselLowering
743 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
745 if (extractOp.getNumIndices() != 1)
748 auto resultType = extractOp.getResult().getType();
749 auto resultVectorType = dyn_cast<VectorType>(resultType);
750 if (!resultVectorType)
754 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
758 auto maskType = createMaskOp.getVectorType();
759 if (maskType.getRank() != 2 || !maskType.allDimsScalable())
762 auto isSVEPredicateSize = [](int64_t size) {
763 return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
766 auto rowsBaseSize = maskType.getDimSize(0);
767 auto colsBaseSize = maskType.getDimSize(1);
768 if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
770 createMaskOp,
"mask dimensions not SVE predicate-sized");
772 auto loc = extractOp.getLoc();
779 auto rowMask = rewriter.
create<vector::CreateMaskOp>(
780 loc, rowMaskType, createMaskOp.getOperand(0));
781 auto colMask = rewriter.
create<vector::CreateMaskOp>(
782 loc, colMaskType, createMaskOp.getOperand(1));
797 patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
798 TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
799 TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
800 VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
801 VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
802 VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
803 ExtractFromCreateMaskToPselLowering>(&ctx);
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
IntegerAttr getIndexAttr(int64_t value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx)
Collect a set of patterns to lower Vector ops to ArmSME ops that map to LLVM intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...