15 #include "llvm/Support/Casting.h"
21 if (llvm::isa<FloatType>(elemType))
23 if (llvm::isa<IntegerType>(elemType))
33 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
34 auto minTileSlices = rewriter.
create<arith::ConstantIndexOp>(
38 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
40 rewriter.
create<arith::MulIOp>(loc, minTileSlices, vscale);
42 rewriter.
create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
48 static arm_sme::CastTileToVector
51 unsigned tileElementWidth = type.getElementType().getIntOrFloatBitWidth();
54 auto tileId = rewriter.
create<arm_sme::GetTileID>(
58 return rewriter.
create<arm_sme::CastTileToVector>(loc, type, tileId);
86 struct TransferReadToArmSMELowering
90 LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
93 if (transferReadOp.getTransferRank() != 2)
94 return rewriter.notifyMatchFailure(transferReadOp,
95 "not a 2 result permutation map");
97 auto vectorType = transferReadOp.getVectorType();
99 return rewriter.notifyMatchFailure(transferReadOp,
100 "not a valid vector type for SME");
102 if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
103 return rewriter.notifyMatchFailure(transferReadOp,
"not a memref source");
106 if (transferReadOp.hasOutOfBoundsDim())
107 return rewriter.notifyMatchFailure(transferReadOp,
108 "not inbounds transfer read");
110 arm_sme::TileSliceLayout layout;
113 bindDims(transferReadOp.getContext(), d0, d1);
116 layout = arm_sme::TileSliceLayout::Horizontal;
118 transferReadOp.getContext()))
119 layout = arm_sme::TileSliceLayout::Vertical;
121 return rewriter.notifyMatchFailure(transferReadOp,
122 "unsupported permutation map");
127 auto mask = transferReadOp.getMask();
128 auto padding = mask ? transferReadOp.getPadding() :
nullptr;
129 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
130 transferReadOp, vectorType, transferReadOp.getSource(),
131 transferReadOp.getIndices(), padding, mask, layout);
164 struct TransferWriteToArmSMELowering
168 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
170 auto vType = writeOp.getVectorType();
174 if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
178 if (writeOp.hasOutOfBoundsDim())
179 return rewriter.notifyMatchFailure(writeOp,
180 "not inbounds transfer write");
183 bindDims(writeOp.getContext(), d0, d1);
186 writeOp.getContext()));
189 return rewriter.notifyMatchFailure(writeOp,
190 "unsupported permutation map");
192 arm_sme::TileSliceLayout layout =
193 isTranspose ? arm_sme::TileSliceLayout::Vertical
194 : arm_sme::TileSliceLayout::Horizontal;
196 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
197 writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
198 writeOp.getMask(), layout);
204 struct VectorLoadToArmSMELowering :
public OpRewritePattern<vector::LoadOp> {
213 load, load.getVectorType(), load.getBase(), load.getIndices());
220 struct VectorStoreToArmSMELowering :
public OpRewritePattern<vector::StoreOp> {
229 store, store.getValueToStore(), store.getBase(), store.getIndices());
236 struct ConstantOpToArmSMELowering :
public OpRewritePattern<arith::ConstantOp> {
241 auto tileType = dyn_cast<VectorType>(constantOp.getType());
245 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
246 if (!denseAttr || !denseAttr.isSplat())
249 auto tileElementType = tileType.getElementType();
260 auto loc = constantOp.getLoc();
267 tileSliceType, denseAttr.getSplatValue<
Attribute>());
268 auto constantOp1D = rewriter.
create<arith::ConstantOp>(loc, denseAttr1D);
270 arm_sme::CastTileToVector
tile =
274 auto tileSliceIndex = forOp.getInductionVar();
277 rewriter.
create<arm_sme::MoveVectorToTileSliceOp>(
278 loc, tileType, constantOp1D,
tile, tileSliceIndex);
301 struct BroadcastOpToArmSMELowering
305 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
307 auto tileType = broadcastOp.getResultVectorType();
312 auto loc = broadcastOp.getLoc();
314 auto srcType = broadcastOp.getSourceType();
315 auto srcVectorType = dyn_cast<VectorType>(srcType);
316 auto tileElementType = tileType.getElementType();
319 if (srcType.isIntOrFloat() ||
320 (srcVectorType && (srcVectorType.getRank() == 0))) {
325 broadcastOp1D = rewriter.
create<vector::BroadcastOp>(
326 loc, tileSliceType, broadcastOp.getSource());
327 }
else if (srcVectorType && (srcVectorType.getRank() == 1))
329 broadcastOp1D = broadcastOp.getSource();
333 arm_sme::CastTileToVector
tile =
338 auto tileSliceIndex = forOp.getInductionVar();
342 rewriter.
create<arm_sme::MoveVectorToTileSliceOp>(
343 loc, tileType, broadcastOp1D,
tile, tileSliceIndex);
371 auto tileType = splatOp.getResult().getType();
376 auto loc = splatOp.getLoc();
378 auto srcType = splatOp.getOperand().getType();
379 auto tileElementType = tileType.getElementType();
381 assert(srcType.isIntOrFloat() &&
"Invalid source type for vector.splat");
387 Value broadcastOp1D = rewriter.
create<vector::BroadcastOp>(
388 loc, tileSliceType, splatOp.getInput());
390 arm_sme::CastTileToVector
tile =
396 auto tileSliceIndex = forOp.getInductionVar();
398 rewriter.
create<arm_sme::MoveVectorToTileSliceOp>(
399 loc, tileType, broadcastOp1D,
tile, tileSliceIndex);
429 struct TransposeOpToArmSMELowering
433 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
435 auto tileType = transposeOp.getResultVectorType();
441 if (permutation[0] != 1 || permutation[1] != 0)
445 auto loc = transposeOp.getLoc();
450 Value minTileSlices = rewriter.
create<arith::ConstantOp>(
454 Value numTileSlices =
455 rewriter.
create<arith::MulIOp>(loc, vscale, minTileSlices);
458 tileType.getElementType());
459 auto buffer = rewriter.
create<memref::AllocaOp>(
460 loc, bufferType,
ValueRange{numTileSlices, numTileSlices});
462 Value input = transposeOp.getVector();
465 auto tileStoreOp = rewriter.
create<arm_sme::TileStoreOp>(
470 transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
471 arm_sme::TileSliceLayout::Vertical);
510 struct VectorOuterProductToArmSMELowering
515 LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
520 if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
521 return outerProductOp.emitError(
"AXPY operations not supported");
524 outerProductOp.getResultVectorType()))
525 return outerProductOp.emitError(
526 "outer product does not fit into SME tile");
528 auto kind = outerProductOp.getKind();
529 if (kind != vector::CombiningKind::ADD)
530 return outerProductOp.emitError(
531 "unsupported kind (lowering to SME only supports ADD at the moment)");
536 auto loc = outerProductOp.
getLoc();
537 if (outerProductOp.isMasked()) {
538 auto maskOp = outerProductOp.getMaskingOp();
541 auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
544 std::tie(lhsMask, rhsMask) = *operandMasks;
548 rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
549 outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
558 auto createMaskOp = mask.
getDefiningOp<vector::CreateMaskOp>();
562 auto maskType = createMaskOp.getVectorType();
563 Value lhsMaskDim = createMaskOp.getOperand(0);
564 Value rhsMaskDim = createMaskOp.getOperand(1);
568 rewriter.
create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
570 rewriter.
create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
572 return std::make_pair(lhsMask, rhsMask);
588 struct VectorExtractToArmSMELowering
594 VectorType sourceType = extractOp.getSourceVectorType();
598 auto loc = extractOp.getLoc();
599 auto position = extractOp.getMixedPosition();
601 Value sourceVector = extractOp.getVector();
604 if (position.empty()) {
605 rewriter.
replaceOp(extractOp, sourceVector);
610 auto moveTileSliceToVector =
611 rewriter.
create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
614 if (position.size() == 1) {
616 rewriter.
replaceOp(extractOp, moveTileSliceToVector);
621 assert(position.size() == 2);
623 extractOp, moveTileSliceToVector, position[1]);
645 struct VectorInsertToArmSMELowering
651 VectorType resultType = insertOp.getResult().getType();
656 auto loc = insertOp.getLoc();
657 auto position = insertOp.getMixedPosition();
659 Value source = insertOp.getSource();
663 if (position.empty()) {
668 Value tileSlice = source;
670 if (position.size() == 2) {
673 tileSlice = rewriter.
create<arm_sme::MoveTileSliceToVectorOp>(
674 loc, insertOp.getDest(), sliceIndex);
675 tileSlice = rewriter.
create<vector::InsertOp>(loc, source, tileSlice,
681 insertOp, tileSlice, insertOp.getDest(), sliceIndex);
690 patterns.
add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
691 SplatOpToArmSMELowering, TransferReadToArmSMELowering,
692 TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
693 VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
694 VectorOuterProductToArmSMELowering,
695 VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
static arm_sme::CastTileToVector getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc, VectorType type)
Returns a tile of the given vector type.
static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Type eltType)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index.
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
void populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx)
Collect a set of patterns to lower Vector ops to ArmSME ops that map to LLVM intrinsics.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...