22 #define GEN_PASS_DEF_CONVERTARMSMETOSCFPASS
23 #include "mlir/Conversion/Passes.h.inc"
36 assert((rank == 1 || rank == 2) &&
"memref has unexpected rank!");
39 auto tileSliceOffset = tileSliceIndex;
42 rewriter.
create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
44 auto baseIndexPlusTileSliceOffset =
45 rewriter.
create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
46 outIndices.push_back(baseIndexPlusTileSliceOffset);
49 outIndices.push_back(indices[1]);
55 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
61 PatternRewriter::InsertionGuard guard(rewriter);
63 auto minTileSlices = rewriter.
create<arith::ConstantIndexOp>(
74 rewriter.
create<arith::MulIOp>(loc, minTileSlices, vscale);
79 auto createMaskOp = mask.
getDefiningOp<vector::CreateMaskOp>();
81 auto maskDim1 = createMaskOp.getOperands()[1];
86 auto numRowI64 = rewriter.
create<arith::IndexCastOp>(
88 auto numTileSlicesI64 = rewriter.
create<arith::IndexCastOp>(
91 rewriter.
create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
92 upperBound = rewriter.
create<arith::IndexCastOp>(
96 rewriter.
create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
98 upperBound = numTileSlices;
100 predicate = rewriter.
create<arith::ConstantOp>(
104 bool hasCarriedArgs = bool(initTile);
105 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
106 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
107 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step,
112 Value tileSliceIndex = forOp.getInductionVar();
114 auto adjustedIndices = getMemrefIndices(
115 memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
116 auto nextTile = makeLoopBody(
117 tileSliceIndex, adjustedIndices, predicate,
118 hasCarriedArgs ? forOp.getRegionIterArg(0) :
Value{});
120 assert(
bool(nextTile) == hasCarriedArgs);
122 rewriter.
create<scf::YieldOp>(loc, nextTile);
127 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
132 return createLoadStoreForOverTileSlices(
133 rewriter, loc, tileType, memrefIndices, memrefRank, mask,
Value{},
136 makeLoopBody(index, adjustedIndices, predicate);
171 struct TileLoadOpConversion :
public OpRewritePattern<arm_sme::TileLoadOp> {
174 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
176 auto loc = tileLoadOp.getLoc();
177 auto tileType = tileLoadOp.getVectorType();
178 auto mask = tileLoadOp.getMask();
184 loc,
"unsupported mask op, only 'vector.create_mask' is "
185 "currently supported");
186 auto padOp = tileLoadOp.getPadding();
187 assert(padOp &&
"expected padding when masking!");
189 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
190 if (!constPadOp || constPadOp.getValue() !=
193 tileLoadOp,
"op has non-zero pad, needs non-zero pad pattern");
198 initTile = rewriter.
create<arm_sme::ZeroOp>(loc, tileType);
200 initTile = rewriter.
create<arm_sme::GetTileOp>(loc, tileType);
204 auto forOp = createLoadStoreForOverTileSlices(
205 rewriter, loc, tileType, tileLoadOp.getIndices(),
206 tileLoadOp.getMemRefType().getRank(), mask, initTile,
211 return rewriter.create<arm_sme::LoadTileSliceOp>(
212 loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
213 memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
220 rewriter.
replaceOp(tileLoadOp, forOp->getResult(0));
253 struct TileLoadOpWithMaskAndPadNonZeroConversion
257 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
260 auto loc = tileLoadOp.getLoc();
261 auto tileType = tileLoadOp.getVectorType();
262 auto tileElementType = tileType.getElementType();
264 auto maskOp = tileLoadOp.getMask();
267 tileLoadOp,
"op has no mask, needs unmasked pattern");
269 auto padOp = tileLoadOp.getPadding();
270 assert(padOp &&
"expected padding when masking!");
272 auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
275 tileLoadOp,
"unsupported mask op, only 'vector.create_mask' is "
276 "currently supported");
278 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
280 constPadOp.getValue() == rewriter.
getZeroAttr(tileElementType))
282 tileLoadOp,
"op has constant zero pad, needs zero pad pattern");
284 auto numRows = createMaskOp.getOperands()[0];
285 auto numCols = createMaskOp.getOperands()[1];
287 auto numColsI32 = rewriter.
create<arith::IndexCastUIOp>(
290 auto initTile = rewriter.
create<arm_sme::GetTileOp>(loc, tileType);
293 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
294 auto minTileSlices = rewriter.
create<arith::ConstantIndexOp>(
298 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
300 rewriter.
create<arith::MulIOp>(loc, minTileSlices, vscale);
301 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, numTileSlices,
306 auto tileSliceIndex = forOp.getInductionVar();
307 auto currentTile = forOp.getRegionIterArg(0);
310 auto rowIsActive = rewriter.
create<arith::CmpIOp>(
311 loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
312 auto rowIsActiveI32 = rewriter.
create<arith::ExtSIOp>(
314 auto mask = rewriter.
create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
319 auto maskOp1D = rewriter.
create<vector::CreateMaskOp>(
320 loc, predicateType, maskIndex.getResult());
322 auto memrefIndices = getMemrefIndices(
323 tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
324 tileSliceIndex, numTileSlices, loc, rewriter);
328 auto pad1DOp = rewriter.
create<vector::SplatOp>(loc, tileSliceType, padOp);
330 auto loadSlice = rewriter.
create<vector::MaskedLoadOp>(
331 loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
335 auto insertSlice = rewriter.
create<arm_sme::InsertTileSliceOp>(
336 loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
337 tileLoadOp.getLayout());
338 rewriter.
create<scf::YieldOp>(loc, insertSlice.getResult());
343 rewriter.
replaceOp(tileLoadOp, forOp.getResult(0));
370 struct TileStoreOpConversion :
public OpRewritePattern<arm_sme::TileStoreOp> {
373 LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
375 if (
Value mask = tileStoreOp.getMask()) {
376 if (!mask.getDefiningOp<vector::CreateMaskOp>())
378 tileStoreOp.getLoc(),
379 "unsupported mask op, only 'vector.create_mask' is "
380 "currently supported");
384 return createLoadStoreForOverTileSlices(
385 rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
386 tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
387 tileStoreOp.getMask(),
389 rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
390 tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
391 predicate, tileStoreOp.getBase(), memrefIndices,
392 tileStoreOp.getLayout());
400 patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
401 TileStoreOpConversion>(
patterns.getContext());
406 struct ConvertArmSMEToSCFPass
407 :
public impl::ConvertArmSMEToSCFPassBase<ConvertArmSMEToSCFPass> {
408 void runOnOperation()
override {
412 target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
413 arith::ArithDialect, scf::SCFDialect>();
414 target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
static MLIRContext * getContext(OpFoldResult val)
TypedAttr getZeroAttr(Type type)
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
operand_range getOperands()
Returns an iterator over the underlying Values.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase the dimension at position pos from the shape.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Include the generated interface declarations.
void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the ArmSME dialect to SCF.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...