15 #include "llvm/Support/Casting.h"
// Rewrite pattern that replaces a vector::TransferReadOp with an
// arm_sme::TileLoadOp once all of the guards below have passed.
// NOTE(review): this is a listing excerpt; the leading numbers are the
// original file's line numbers and several lines are elided.
44 struct TransferReadToArmSMELowering
48 LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
// Guard: only transfers of rank 2 are handled.
51 if (transferReadOp.getTransferRank() != 2)
52 return rewriter.notifyMatchFailure(transferReadOp,
53 "not a 2 result permutation map");
// Guard: the vector type must be acceptable for SME. The condition itself
// (original line 56) is not visible here.
55 auto vectorType = transferReadOp.getVectorType();
57 return rewriter.notifyMatchFailure(transferReadOp,
58 "not a valid vector type for SME");
// Guard: the source must be a memref.
60 if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
61 return rewriter.notifyMatchFailure(transferReadOp,
"not a memref source");
// Guard: reads with any out-of-bounds dimension are rejected.
64 if (transferReadOp.hasOutOfBoundsDim())
65 return rewriter.notifyMatchFailure(transferReadOp,
66 "not inbounds transfer read");
// Pick the tile slice layout (Horizontal or Vertical); anything else is
// reported as an unsupported permutation map. The conditions selecting each
// layout are elided -- presumably identity vs. transposed map, TODO confirm.
68 arm_sme::TileSliceLayout layout;
71 bindDims(transferReadOp.getContext(), d0, d1);
74 layout = arm_sme::TileSliceLayout::Horizontal;
76 transferReadOp.getContext()))
77 layout = arm_sme::TileSliceLayout::Vertical;
79 return rewriter.notifyMatchFailure(transferReadOp,
80 "unsupported permutation map");
// The padding value is only forwarded when a mask is present; otherwise a
// null padding is passed to the tile load.
85 auto mask = transferReadOp.getMask();
86 auto padding = mask ? transferReadOp.getPadding() :
nullptr;
87 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
88 transferReadOp, vectorType, transferReadOp.getSource(),
89 transferReadOp.getIndices(), padding, mask, layout);
// Rewrite pattern that replaces a vector::TransferWriteOp with an
// arm_sme::TileStoreOp.
// NOTE(review): listing excerpt; leading numbers are original line numbers
// and some lines (including several failure returns) are elided.
122 struct TransferWriteToArmSMELowering
126 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
128 auto vType = writeOp.getVectorType();
// Guard: the destination must be a memref (the failure return is elided).
132 if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
// Guard: writes with any out-of-bounds dimension are rejected.
136 if (writeOp.hasOutOfBoundsDim())
137 return rewriter.notifyMatchFailure(writeOp,
138 "not inbounds transfer write");
// Permutation-map check; only part of the isTranspose computation is
// visible. Maps that pass neither check are reported as unsupported.
141 bindDims(writeOp.getContext(), d0, d1);
144 writeOp.getContext()));
147 return rewriter.notifyMatchFailure(writeOp,
148 "unsupported permutation map");
// A transposed write stores vertical tile slices; otherwise horizontal.
150 arm_sme::TileSliceLayout layout =
151 isTranspose ? arm_sme::TileSliceLayout::Vertical
152 : arm_sme::TileSliceLayout::Horizontal;
// The vector, destination, indices and (possibly null) mask are forwarded
// unchanged to the tile store.
154 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
155 writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
156 writeOp.getMask(), layout);
// Rewrite pattern matching vector::LoadOp.
// NOTE(review): only the argument list of the replacement is visible (the
// load's vector type, base and indices); the op being created is elided --
// presumably an ArmSME tile load, TODO confirm.
162 struct VectorLoadToArmSMELowering :
public OpRewritePattern<vector::LoadOp> {
171 load, load.getVectorType(), load.getBase(), load.getIndices());
// Rewrite pattern matching vector::StoreOp.
// NOTE(review): only the argument list of the replacement is visible (the
// stored value, base and indices); the op being created is elided --
// presumably an ArmSME tile store, TODO confirm.
178 struct VectorStoreToArmSMELowering :
public OpRewritePattern<vector::StoreOp> {
187 store, store.getValueToStore(), store.getBase(), store.getIndices());
// Rewrite pattern for vector::BroadcastOp producing a tile: a 1-D vector is
// obtained from the source and moved into the tile slice by slice.
// NOTE(review): listing excerpt; the tile-type guard, the tileSliceType
// definition and the loop construction are on elided lines.
212 struct BroadcastOpToArmSMELowering
216 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
218 auto tileType = broadcastOp.getResultVectorType();
222 auto loc = broadcastOp.getLoc();
224 auto srcType = broadcastOp.getSourceType();
225 auto srcVectorType = dyn_cast<VectorType>(srcType);
// Scalar or rank-0 vector source: first broadcast it to a 1-D vector of
// tileSliceType. A rank-1 source is used as the 1-D vector directly.
228 if (srcType.isIntOrFloat() ||
229 (srcVectorType && (srcVectorType.getRank() == 0))) {
232 broadcastOp1D = rewriter.
create<vector::BroadcastOp>(
233 loc, tileSliceType, broadcastOp.getSource());
234 }
else if (srcVectorType && (srcVectorType.getRank() == 1))
236 broadcastOp1D = broadcastOp.getSource();
// Starting tile that the loop body updates.
240 auto initTile = rewriter.
create<arm_sme::GetTileOp>(loc, tileType);
// Loop body: write the 1-D vector into the slice at tileSliceIndex of the
// current tile, yielding the updated tile.
246 auto nextTile = b.
create<arm_sme::MoveVectorToTileSliceOp>(
247 loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
// The loop's first result (forOp is created on an elided line) replaces the
// broadcast.
255 rewriter.
replaceOp(broadcastOp, forOp.getResult(0));
// Body of the splat lowering (the enclosing struct header is not visible in
// this excerpt; presumably SplatOpToArmSMELowering -- TODO confirm).
// The scalar input is broadcast to a 1-D vector which is then moved into the
// tile slice by slice.
285 auto tileType = splatOp.getResult().getType();
289 auto loc = splatOp.getLoc();
290 auto srcType = splatOp.getOperand().getType();
// The splat source is expected to be an integer or float scalar.
292 assert(srcType.isIntOrFloat() &&
"Invalid source type for vector.splat");
// Broadcast the scalar to a 1-D vector of tileSliceType (defined on an
// elided line).
298 Value broadcastOp1D = rewriter.
create<vector::BroadcastOp>(
299 loc, tileSliceType, splatOp.getInput());
// Starting tile that the loop body updates.
301 auto initTile = rewriter.
create<arm_sme::GetTileOp>(loc, tileType);
// Loop body: write the 1-D vector into the slice at tileSliceIndex of the
// current tile.
305 auto nextTile = b.
create<arm_sme::MoveVectorToTileSliceOp>(
306 loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
// The loop's first result (forOp is created on an elided line) replaces the
// splat.
315 rewriter.
replaceOp(splatOp, forOp.getResult(0));
// Rewrite pattern for vector::TransposeOp on a tile: the input is stored to
// a stack buffer and the result is produced from that buffer using the
// Vertical tile slice layout.
// NOTE(review): listing excerpt; several lines, including the name of the
// replacement op, are elided.
343 struct TransposeOpToArmSMELowering
347 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
349 auto tileType = transposeOp.getResultVectorType();
// Guard: only the permutation [1, 0] is handled (failure return elided).
355 if (permutation[0] != 1 || permutation[1] != 0)
358 auto loc = transposeOp.getLoc();
// numTileSlices = vscale * minTileSlices; the constant's value and the
// definition of vscale are on elided lines.
363 Value minTileSlices = rewriter.
create<arith::ConstantOp>(
367 Value numTileSlices =
368 rewriter.
create<arith::MulIOp>(loc, vscale, minTileSlices);
// Allocate a numTileSlices x numTileSlices buffer with memref.alloca, using
// the tile's element type.
371 tileType.getElementType());
372 auto buffer = rewriter.
create<memref::AllocaOp>(
373 loc, bufferType,
ValueRange{numTileSlices, numTileSlices});
375 Value input = transposeOp.getVector();
// Store the input tile (store operands are elided here).
378 auto tileStoreOp = rewriter.
create<arm_sme::TileStoreOp>(
// Replace the transpose with an op built from the store's base and indices,
// using the Vertical layout. The op name is elided -- presumably a tile
// load, TODO confirm.
383 transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
384 arm_sme::TileSliceLayout::Vertical);
// Rewrite pattern for vector::OuterProductOp, followed by the body of the
// helper that splits a 2-D result mask into per-operand masks.
// NOTE(review): listing excerpt; the helper's header, several failure
// returns and the name of the replacement op are elided.
423 struct VectorOuterProductToArmSMELowering
428 LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
// Guard: the RHS operand must be a vector (AXPY form is rejected).
433 if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
435 "AXPY operations not supported");
// Guard: the result vector type must fit an SME tile (the predicate name is
// on an elided line).
438 outerProductOp.getResultVectorType()))
440 outerProductOp,
"outer product does not fit into SME tile");
// Guard: only the ADD combining kind is supported.
442 auto kind = outerProductOp.getKind();
443 if (kind != vector::CombiningKind::ADD)
446 "unsupported kind (lowering to SME only supports ADD at the moment)");
451 auto loc = outerProductOp.
getLoc();
// Masked case: split the result mask into an LHS mask and an RHS mask.
452 if (outerProductOp.isMasked()) {
453 auto maskOp = outerProductOp.getMaskingOp();
456 auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
459 std::tie(lhsMask, rhsMask) = *operandMasks;
// Replacement operands: result type, lhs, rhs, both masks and the
// accumulator. rootOp is defined on an elided line.
463 rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
464 outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
// Mask-decomposition helper body: the mask is looked up as the result of a
// vector::CreateMaskOp (handling of other producers is elided).
473 auto createMaskOp = mask.
getDefiningOp<vector::CreateMaskOp>();
// Operand 0 of the create_mask becomes the LHS mask dimension and operand 1
// the RHS mask dimension.
477 auto maskType = createMaskOp.getVectorType();
478 Value lhsMaskDim = createMaskOp.getOperand(0);
479 Value rhsMaskDim = createMaskOp.getOperand(1);
// Build one mask per operand, each of operandMaskType (defined on an elided
// line), and return them as a pair.
483 rewriter.
create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
485 rewriter.
create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
487 return std::make_pair(lhsMask, rhsMask);
// Rewrite pattern for vector.extract on a tile. Behaviour depends on the
// number of position indices: 0 (whole tile), 1 (tile slice) or 2 (element
// of a slice).
// NOTE(review): listing excerpt; guards and some operands are elided.
503 struct VectorExtractToArmSMELowering
509 VectorType sourceType = extractOp.getSourceVectorType();
513 auto loc = extractOp.getLoc();
514 auto position = extractOp.getMixedPosition();
516 Value sourceVector = extractOp.getVector();
// No position: the extract is replaced by its source vector.
519 if (position.empty()) {
520 rewriter.
replaceOp(extractOp, sourceVector);
// Otherwise pull a slice out of the tile (the slice index operand is on an
// elided line).
525 auto moveTileSliceToVector =
526 rewriter.
create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
// One index: the slice itself is the result.
529 if (position.size() == 1) {
531 rewriter.
replaceOp(extractOp, moveTileSliceToVector);
// Two indices: the replacement is built from the slice and position[1]
// (the op name is elided).
536 assert(position.size() == 2);
538 extractOp, moveTileSliceToVector, position[1]);
// Rewrite pattern for vector.insert into a tile. Behaviour depends on the
// number of position indices.
// NOTE(review): listing excerpt; guards, the empty-position handling and the
// name of the final replacement op are elided.
560 struct VectorInsertToArmSMELowering
566 VectorType resultType = insertOp.getResult().getType();
571 auto loc = insertOp.getLoc();
572 auto position = insertOp.getMixedPosition();
574 Value source = insertOp.getSource();
578 if (position.empty()) {
// By default the inserted source is treated as the tile slice.
583 Value tileSlice = source;
// Two indices: fetch the existing slice from the destination tile and
// insert the source value into it (remaining InsertOp operands elided).
585 if (position.size() == 2) {
588 tileSlice = rewriter.
create<arm_sme::MoveTileSliceToVectorOp>(
589 loc, insertOp.getDest(), sliceIndex);
590 tileSlice = rewriter.
create<vector::InsertOp>(loc, source, tileSlice,
// Replacement operands: the slice, the destination tile and sliceIndex.
596 insertOp, tileSlice, insertOp.getDest(), sliceIndex);
// Rewrite pattern for vector::PrintOp: emits an scf.for loop that extracts
// each tile slice and prints it with a vector::PrintOp.
// NOTE(review): listing excerpt; guards, the definition of loc, the
// insertion-point handling and the final replacement are elided.
622 struct VectorPrintToArmSMELowering :
public OpRewritePattern<vector::PrintOp> {
630 VectorType vectorType = dyn_cast<VectorType>(
printOp.getPrintType());
637 auto vscale = rewriter.
create<vector::VectorScaleOp>(loc);
// Constant index holding dimension 0 of the vector type; the variable it is
// assigned to is on an elided line (presumably minTileRows, which is used in
// the upper bound below).
639 rewriter.
create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
// Loop from 0 to minTileRows * vscale in steps of 1.
640 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
641 auto upperBound = rewriter.
create<arith::MulIOp>(loc, minTileRows, vscale);
642 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
643 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step);
// Extract the slice at the loop's induction variable from the printed
// source.
648 Value rowIndex = forOp.getInductionVar();
649 auto tileSlice = rewriter.
create<arm_sme::MoveTileSliceToVectorOp>(
650 loc,
printOp.getSource(), rowIndex);
// Print the extracted slice (remaining PrintOp operands elided).
652 rewriter.
create<vector::PrintOp>(loc, tileSlice,
// Registers the eleven Vector-to-ArmSME lowering patterns listed below with
// the pattern set, constructing each from the MLIR context.
// NOTE(review): the enclosing function header is not visible in this
// excerpt.
665 patterns.
add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
666 TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
667 TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
668 VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
669 VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
670 VectorPrintToArmSMELowering>(&ctx);
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
IntegerAttr getIndexAttr(int64_t value)
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx)
Collect a set of patterns to lower Vector ops to ArmSME ops that map to LLVM intrinsics.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.