16 #include "llvm/Support/Casting.h"
45 struct TransferReadToArmSMELowering
49 LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
52 if (transferReadOp.getTransferRank() != 2)
53 return rewriter.notifyMatchFailure(transferReadOp,
54 "not a 2 result permutation map");
56 auto vectorType = transferReadOp.getVectorType();
58 return rewriter.notifyMatchFailure(transferReadOp,
59 "not a valid vector type for SME");
61 if (!llvm::isa<MemRefType>(transferReadOp.getBase().getType()))
62 return rewriter.notifyMatchFailure(transferReadOp,
"not a memref source");
65 if (transferReadOp.hasOutOfBoundsDim())
66 return rewriter.notifyMatchFailure(transferReadOp,
67 "not inbounds transfer read");
71 return rewriter.notifyMatchFailure(transferReadOp,
72 "unsupported permutation map");
77 arm_sme::TileSliceLayout layout =
78 transposed ? arm_sme::TileSliceLayout::Vertical
79 : arm_sme::TileSliceLayout::Horizontal;
84 auto mask = transferReadOp.getMask();
85 auto padding = mask ? transferReadOp.getPadding() :
nullptr;
86 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
87 transferReadOp, vectorType, transferReadOp.getBase(),
88 transferReadOp.getIndices(), padding, mask, layout);
121 struct TransferWriteToArmSMELowering
125 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
127 auto vType = writeOp.getVectorType();
131 if (!llvm::isa<MemRefType>(writeOp.getBase().getType()))
135 if (writeOp.hasOutOfBoundsDim())
136 return rewriter.notifyMatchFailure(writeOp,
137 "not inbounds transfer write");
141 return rewriter.notifyMatchFailure(writeOp,
142 "unsupported permutation map");
147 arm_sme::TileSliceLayout layout =
148 transposed ? arm_sme::TileSliceLayout::Vertical
149 : arm_sme::TileSliceLayout::Horizontal;
151 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
152 writeOp, writeOp.getVector(), writeOp.getBase(), writeOp.getIndices(),
153 writeOp.getMask(), layout);
159 struct VectorLoadToArmSMELowering :
public OpRewritePattern<vector::LoadOp> {
162 LogicalResult matchAndRewrite(vector::LoadOp load,
168 load, load.getVectorType(), load.getBase(), load.getIndices());
175 struct VectorStoreToArmSMELowering :
public OpRewritePattern<vector::StoreOp> {
178 LogicalResult matchAndRewrite(vector::StoreOp store,
184 store, store.getValueToStore(), store.getBase(), store.getIndices());
209 struct BroadcastOpToArmSMELowering
213 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
215 auto tileType = broadcastOp.getResultVectorType();
219 auto loc = broadcastOp.getLoc();
221 auto srcType = broadcastOp.getSourceType();
222 auto srcVectorType = dyn_cast<VectorType>(srcType);
225 if (srcType.isIntOrFloat() ||
226 (srcVectorType && (srcVectorType.getRank() == 0))) {
229 broadcastOp1D = vector::BroadcastOp::create(rewriter, loc, tileSliceType,
230 broadcastOp.getSource());
231 }
else if (srcVectorType && (srcVectorType.getRank() == 1))
233 broadcastOp1D = broadcastOp.getSource();
237 auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
243 auto nextTile = arm_sme::InsertTileSliceOp::create(
244 b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
245 return nextTile.getResult();
252 rewriter.
replaceOp(broadcastOp, forOp.getResult(0));
280 struct TransposeOpToArmSMELowering
284 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
286 auto tileType = transposeOp.getResultVectorType();
292 if (permutation[0] != 1 || permutation[1] != 0)
295 auto loc = transposeOp.getLoc();
296 Value input = transposeOp.getVector();
298 if (
auto xferOp = input.
getDefiningOp<vector::TransferReadOp>();
303 xferOp->setAttr(xferOp.getPermutationMapAttrName(),
305 permutation, transposeOp.getContext())));
313 vector::VectorScaleOp::create(rewriter, loc, rewriter.
getIndexType());
314 Value minTileSlices = arith::ConstantOp::create(
315 rewriter, loc, rewriter.
getIndexAttr(tileType.getDimSize(0)));
317 arith::ConstantOp::create(rewriter, loc, rewriter.
getIndexAttr(0));
318 Value numTileSlices =
319 arith::MulIOp::create(rewriter, loc, vscale, minTileSlices);
322 tileType.getElementType());
323 auto buffer = memref::AllocaOp::create(
324 rewriter, loc, bufferType,
ValueRange{numTileSlices, numTileSlices});
327 auto tileStoreOp = arm_sme::TileStoreOp::create(rewriter, loc, input,
332 transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
333 arm_sme::TileSliceLayout::Vertical);
372 struct VectorOuterProductToArmSMELowering
377 LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
382 if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
384 "AXPY operations not supported");
387 outerProductOp.getResultVectorType()))
389 outerProductOp,
"outer product does not fit into SME tile");
391 auto kind = outerProductOp.getKind();
392 if (
kind != vector::CombiningKind::ADD)
395 "unsupported kind (lowering to SME only supports ADD at the moment)");
400 auto loc = outerProductOp.
getLoc();
401 if (outerProductOp.isMasked()) {
402 auto maskOp = outerProductOp.getMaskingOp();
405 auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
406 if (failed(operandMasks))
408 std::tie(lhsMask, rhsMask) = *operandMasks;
412 rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
413 outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
418 static FailureOr<std::pair<Value, Value>>
422 auto createMaskOp = mask.
getDefiningOp<vector::CreateMaskOp>();
426 auto maskType = createMaskOp.getVectorType();
427 Value lhsMaskDim = createMaskOp.getOperand(0);
428 Value rhsMaskDim = createMaskOp.getOperand(1);
431 Value lhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType,
433 Value rhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType,
436 return std::make_pair(lhsMask, rhsMask);
452 struct VectorExtractToArmSMELowering
456 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
458 VectorType sourceType = extractOp.getSourceVectorType();
462 auto loc = extractOp.getLoc();
463 auto position = extractOp.getMixedPosition();
465 Value sourceVector = extractOp.getVector();
468 if (position.empty()) {
469 rewriter.
replaceOp(extractOp, sourceVector);
474 auto extractTileSlice = arm_sme::ExtractTileSliceOp::create(
475 rewriter, loc, sourceVector, sliceIndex);
477 if (position.size() == 1) {
479 rewriter.
replaceOp(extractOp, extractTileSlice);
484 assert(position.size() == 2);
508 struct VectorInsertToArmSMELowering
512 LogicalResult matchAndRewrite(vector::InsertOp insertOp,
514 VectorType resultType = insertOp.getResult().getType();
519 auto loc = insertOp.getLoc();
520 auto position = insertOp.getMixedPosition();
522 Value source = insertOp.getValueToStore();
526 if (position.empty()) {
531 Value tileSlice = source;
533 if (position.size() == 2) {
536 tileSlice = arm_sme::ExtractTileSliceOp::create(
537 rewriter, loc, insertOp.getDest(), sliceIndex);
538 tileSlice = vector::InsertOp::create(rewriter, loc, source, tileSlice,
544 insertOp, tileSlice, insertOp.getDest(), sliceIndex);
570 struct VectorPrintToArmSMELowering :
public OpRewritePattern<vector::PrintOp> {
573 LogicalResult matchAndRewrite(vector::PrintOp
printOp,
578 VectorType vectorType = dyn_cast<VectorType>(
printOp.getPrintType());
585 auto vscale = vector::VectorScaleOp::create(rewriter, loc);
589 auto upperBound = arith::MulIOp::create(rewriter, loc, minTileRows, vscale);
592 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
597 Value rowIndex = forOp.getInductionVar();
598 auto tileSlice = arm_sme::ExtractTileSliceOp::create(
599 rewriter, loc,
printOp.getSource(), rowIndex);
601 vector::PrintOp::create(rewriter, loc, tileSlice,
624 struct FoldTransferWriteOfExtractTileSlice
628 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
630 if (!isa<MemRefType>(writeOp.getBase().getType()))
633 if (writeOp.hasOutOfBoundsDim())
635 "not inbounds transfer write");
637 auto extractTileSlice =
638 writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>();
639 if (!extractTileSlice)
641 writeOp,
"vector to store not from ExtractTileSliceOp");
646 "unsupported permutation map");
648 Value mask = writeOp.getMask();
650 auto maskType = writeOp.getVectorType().clone(rewriter.
getI1Type());
651 mask = arith::ConstantOp::create(rewriter, writeOp.getLoc(), maskType,
656 writeOp, extractTileSlice.getTile(),
657 extractTileSlice.getTileSliceIndex(), mask, writeOp.getBase(),
658 writeOp.getIndices(), extractTileSlice.getLayout());
680 struct ExtractFromCreateMaskToPselLowering
684 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
686 if (extractOp.getNumIndices() != 1)
689 auto resultType = extractOp.getResult().getType();
690 auto resultVectorType = dyn_cast<VectorType>(resultType);
691 if (!resultVectorType)
695 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
699 auto maskType = createMaskOp.getVectorType();
700 if (maskType.getRank() != 2 || !maskType.allDimsScalable())
703 auto isSVEPredicateSize = [](int64_t size) {
704 return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
707 auto rowsBaseSize = maskType.getDimSize(0);
708 auto colsBaseSize = maskType.getDimSize(1);
709 if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
711 createMaskOp,
"mask dimensions not SVE predicate-sized");
713 auto loc = extractOp.getLoc();
720 auto rowMask = vector::CreateMaskOp::create(rewriter, loc, rowMaskType,
721 createMaskOp.getOperand(0));
722 auto colMask = vector::CreateMaskOp::create(rewriter, loc, colMaskType,
723 createMaskOp.getOperand(1));
739 LogicalResult matchAndRewrite(vector::SplatOp splatOp,
752 patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
753 TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
754 TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
755 VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
756 VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
757 VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
758 ExtractFromCreateMaskToPselLowering>(&ctx);
union mlir::linalg::@1224::ArityGroupAndKind::Kind kind
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e. an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
IntegerAttr getIndexAttr(int64_t value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
bool hasOneUse()
Returns true if this operation has exactly one use.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase the dimension at position `pos` from the shape.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx)
Collect a set of patterns to lower Vector ops to ArmSME ops that map to LLVM intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...