16#include "llvm/Support/Casting.h"
45struct TransferReadToArmSMELowering
49 LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
50 PatternRewriter &rewriter)
const final {
52 if (transferReadOp.getTransferRank() != 2)
53 return rewriter.notifyMatchFailure(transferReadOp,
54 "not a 2 result permutation map");
56 auto vectorType = transferReadOp.getVectorType();
58 return rewriter.notifyMatchFailure(transferReadOp,
59 "not a valid vector type for SME");
61 if (!llvm::isa<MemRefType>(transferReadOp.getBase().getType()))
62 return rewriter.notifyMatchFailure(transferReadOp,
"not a memref source");
65 if (transferReadOp.hasOutOfBoundsDim())
66 return rewriter.notifyMatchFailure(transferReadOp,
67 "not inbounds transfer read");
69 AffineMap map = transferReadOp.getPermutationMap();
71 return rewriter.notifyMatchFailure(transferReadOp,
72 "unsupported permutation map");
77 arm_sme::TileSliceLayout layout =
78 transposed ? arm_sme::TileSliceLayout::Vertical
79 : arm_sme::TileSliceLayout::Horizontal;
84 auto mask = transferReadOp.getMask();
85 auto padding = mask ? transferReadOp.getPadding() :
nullptr;
86 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
87 transferReadOp, vectorType, transferReadOp.getBase(),
88 transferReadOp.getIndices(), padding, mask, layout);
121struct TransferWriteToArmSMELowering
125 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
126 PatternRewriter &rewriter)
const final {
127 auto vType = writeOp.getVectorType();
131 if (!llvm::isa<MemRefType>(writeOp.getBase().getType()))
135 if (writeOp.hasOutOfBoundsDim())
136 return rewriter.notifyMatchFailure(writeOp,
137 "not inbounds transfer write");
141 return rewriter.notifyMatchFailure(writeOp,
142 "unsupported permutation map");
147 arm_sme::TileSliceLayout layout =
148 transposed ? arm_sme::TileSliceLayout::Vertical
149 : arm_sme::TileSliceLayout::Horizontal;
151 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
152 writeOp, writeOp.getVector(), writeOp.getBase(), writeOp.getIndices(),
153 writeOp.getMask(), layout);
159struct VectorLoadToArmSMELowering :
public OpRewritePattern<vector::LoadOp> {
162 LogicalResult matchAndRewrite(vector::LoadOp
load,
163 PatternRewriter &rewriter)
const override {
175struct VectorStoreToArmSMELowering :
public OpRewritePattern<vector::StoreOp> {
178 LogicalResult matchAndRewrite(vector::StoreOp store,
179 PatternRewriter &rewriter)
const override {
184 store, store.getValueToStore(), store.getBase(), store.getIndices());
209struct BroadcastOpToArmSMELowering
213 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
214 PatternRewriter &rewriter)
const final {
215 auto tileType = broadcastOp.getResultVectorType();
219 auto loc = broadcastOp.getLoc();
221 auto srcType = broadcastOp.getSourceType();
222 auto srcVectorType = dyn_cast<VectorType>(srcType);
225 if (srcType.isIntOrFloat() ||
226 (srcVectorType && (srcVectorType.getRank() == 0))) {
228 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
229 broadcastOp1D = vector::BroadcastOp::create(rewriter, loc, tileSliceType,
230 broadcastOp.getSource());
231 }
else if (srcVectorType && (srcVectorType.getRank() == 1))
233 broadcastOp1D = broadcastOp.getSource();
237 auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
239 auto makeLoopBody = [&](OpBuilder &
b, Location loc, Value tileSliceIndex,
243 auto nextTile = arm_sme::InsertTileSliceOp::create(
244 b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
245 return nextTile.getResult();
252 rewriter.
replaceOp(broadcastOp, forOp.getResult(0));
280struct TransposeOpToArmSMELowering
284 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
285 PatternRewriter &rewriter)
const final {
286 auto tileType = transposeOp.getResultVectorType();
291 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
292 if (permutation[0] != 1 || permutation[1] != 0)
295 auto loc = transposeOp.getLoc();
296 Value input = transposeOp.getVector();
298 if (
auto xferOp = input.
getDefiningOp<vector::TransferReadOp>();
303 xferOp->setAttr(xferOp.getPermutationMapAttrName(),
305 permutation, transposeOp.getContext())));
313 vector::VectorScaleOp::create(rewriter, loc, rewriter.
getIndexType());
314 Value minTileSlices = arith::ConstantOp::create(
315 rewriter, loc, rewriter.
getIndexAttr(tileType.getDimSize(0)));
317 arith::ConstantOp::create(rewriter, loc, rewriter.
getIndexAttr(0));
318 Value numTileSlices =
319 arith::MulIOp::create(rewriter, loc, vscale, minTileSlices);
321 MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
322 tileType.getElementType());
323 auto buffer = memref::AllocaOp::create(
324 rewriter, loc, bufferType,
ValueRange{numTileSlices, numTileSlices});
327 auto tileStoreOp = arm_sme::TileStoreOp::create(rewriter, loc, input,
332 transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
333 arm_sme::TileSliceLayout::Vertical);
372struct VectorOuterProductToArmSMELowering
377 LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
378 PatternRewriter &rewriter)
const override {
382 if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
384 "AXPY operations not supported");
387 outerProductOp.getResultVectorType()))
389 outerProductOp,
"outer product does not fit into SME tile");
391 auto kind = outerProductOp.getKind();
392 if (kind != vector::CombiningKind::ADD)
395 "unsupported kind (lowering to SME only supports ADD at the moment)");
399 Operation *rootOp = outerProductOp;
400 auto loc = outerProductOp.getLoc();
401 if (outerProductOp.isMasked()) {
402 auto maskOp = outerProductOp.getMaskingOp();
405 auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
408 std::tie(lhsMask, rhsMask) = *operandMasks;
412 rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
413 outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
418 static FailureOr<std::pair<Value, Value>>
419 decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
422 auto createMaskOp = mask.
getDefiningOp<vector::CreateMaskOp>();
426 auto maskType = createMaskOp.getVectorType();
427 Value lhsMaskDim = createMaskOp.getOperand(0);
428 Value rhsMaskDim = createMaskOp.getOperand(1);
430 VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
431 Value lhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType,
433 Value rhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType,
436 return std::make_pair(lhsMask, rhsMask);
452struct VectorExtractToArmSMELowering
456 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
457 PatternRewriter &rewriter)
const override {
458 VectorType sourceType = extractOp.getSourceVectorType();
462 auto loc = extractOp.getLoc();
463 auto position = extractOp.getMixedPosition();
465 Value sourceVector = extractOp.getSource();
468 if (position.empty()) {
469 rewriter.
replaceOp(extractOp, sourceVector);
474 auto extractTileSlice = arm_sme::ExtractTileSliceOp::create(
475 rewriter, loc, sourceVector, sliceIndex);
477 if (position.size() == 1) {
479 rewriter.
replaceOp(extractOp, extractTileSlice);
484 assert(position.size() == 2);
508struct VectorInsertToArmSMELowering
512 LogicalResult matchAndRewrite(vector::InsertOp insertOp,
513 PatternRewriter &rewriter)
const override {
514 VectorType resultType = insertOp.getResult().getType();
519 auto loc = insertOp.getLoc();
520 auto position = insertOp.getMixedPosition();
522 Value source = insertOp.getValueToStore();
526 if (position.empty()) {
531 Value tileSlice = source;
533 if (position.size() == 2) {
536 tileSlice = arm_sme::ExtractTileSliceOp::create(
537 rewriter, loc, insertOp.getDest(), sliceIndex);
538 tileSlice = vector::InsertOp::create(rewriter, loc, source, tileSlice,
544 insertOp, tileSlice, insertOp.getDest(), sliceIndex);
570struct VectorPrintToArmSMELowering :
public OpRewritePattern<vector::PrintOp> {
573 LogicalResult matchAndRewrite(vector::PrintOp
printOp,
574 PatternRewriter &rewriter)
const override {
578 VectorType vectorType = dyn_cast<VectorType>(
printOp.getPrintType());
585 auto vscale = vector::VectorScaleOp::create(rewriter, loc);
589 auto upperBound = arith::MulIOp::create(rewriter, loc, minTileRows, vscale);
592 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
597 Value rowIndex = forOp.getInductionVar();
598 auto tileSlice = arm_sme::ExtractTileSliceOp::create(
599 rewriter, loc,
printOp.getSource(), rowIndex);
601 vector::PrintOp::create(rewriter, loc, tileSlice,
624struct FoldTransferWriteOfExtractTileSlice
628 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
629 PatternRewriter &rewriter)
const final {
630 if (!isa<MemRefType>(writeOp.getBase().getType()))
633 if (writeOp.hasOutOfBoundsDim())
635 "not inbounds transfer write");
637 auto extractTileSlice =
638 writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>();
639 if (!extractTileSlice)
641 writeOp,
"vector to store not from ExtractTileSliceOp");
646 "unsupported permutation map");
648 Value mask = writeOp.getMask();
650 auto maskType = writeOp.getVectorType().clone(rewriter.
getI1Type());
651 mask = arith::ConstantOp::create(rewriter, writeOp.getLoc(), maskType,
656 writeOp, extractTileSlice.getTile(),
657 extractTileSlice.getTileSliceIndex(), mask, writeOp.getBase(),
658 writeOp.getIndices(), extractTileSlice.getLayout());
680struct ExtractFromCreateMaskToPselLowering
684 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
685 PatternRewriter &rewriter)
const override {
686 if (extractOp.getNumIndices() != 1)
689 auto resultType = extractOp.getResult().getType();
690 auto resultVectorType = dyn_cast<VectorType>(resultType);
691 if (!resultVectorType)
695 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
699 auto maskType = createMaskOp.getVectorType();
700 if (maskType.getRank() != 2 || !maskType.allDimsScalable())
703 auto isSVEPredicateSize = [](int64_t size) {
704 return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
707 auto rowsBaseSize = maskType.getDimSize(0);
708 auto colsBaseSize = maskType.getDimSize(1);
709 if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
711 createMaskOp,
"mask dimensions not SVE predicate-sized");
713 auto loc = extractOp.getLoc();
714 VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
715 VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
720 auto rowMask = vector::CreateMaskOp::create(rewriter, loc, rowMaskType,
721 createMaskOp.getOperand(0));
722 auto colMask = vector::CreateMaskOp::create(rewriter, loc, colMaskType,
723 createMaskOp.getOperand(1));
738 patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering,
739 TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
740 VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
741 VectorOuterProductToArmSMELowering,
742 VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
743 VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
744 ExtractFromCreateMaskToPselLowering>(&ctx);
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e. an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
IntegerAttr getIndexAttr(int64_t value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
bool hasOneUse()
Returns true if this operation has exactly one use.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The original operation is erased.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacement). The original op is erased.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
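Generates a for loop over ZA tile slices where the induction variable is the tile slice index and each iteration yields a new tile. The loop body is built via makeLoopBody, which returns the next tile value.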
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx)
Collect a set of patterns to lower Vector ops to ArmSME ops that map to LLVM intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.