22 #define GEN_PASS_DEF_CONVERTARMSMETOSCFPASS
23 #include "mlir/Conversion/Passes.h.inc"
36 assert(rank == 2 &&
"memref has unexpected rank!");
39 auto tileSliceOffset = tileSliceIndex;
41 auto baseIndexPlusTileSliceOffset =
42 rewriter.
create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
43 outIndices.push_back(baseIndexPlusTileSliceOffset);
44 outIndices.push_back(indices[1]);
50 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
56 PatternRewriter::InsertionGuard guard(rewriter);
59 if (memrefIndices.size() != 2)
62 auto minTileSlices = rewriter.
create<arith::ConstantIndexOp>(
73 rewriter.
create<arith::MulIOp>(loc, minTileSlices, vscale);
78 auto createMaskOp = mask.
getDefiningOp<vector::CreateMaskOp>();
80 auto maskDim1 = createMaskOp.getOperands()[1];
85 auto numRowI64 = rewriter.
create<arith::IndexCastOp>(
87 auto numTileSlicesI64 = rewriter.
create<arith::IndexCastOp>(
90 rewriter.
create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
91 upperBound = rewriter.
create<arith::IndexCastOp>(
95 rewriter.
create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
97 upperBound = numTileSlices;
99 predicate = rewriter.
create<arith::ConstantOp>(
103 bool hasCarriedArgs = bool(initTile);
104 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
105 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
106 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step,
111 Value tileSliceIndex = forOp.getInductionVar();
113 auto adjustedIndices = getMemrefIndices(
114 memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
115 auto nextTile = makeLoopBody(
116 tileSliceIndex, adjustedIndices, predicate,
117 hasCarriedArgs ? forOp.getRegionIterArg(0) :
Value{});
119 assert(
bool(nextTile) == hasCarriedArgs);
121 rewriter.
create<scf::YieldOp>(loc, nextTile);
126 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
131 return createLoadStoreForOverTileSlices(
132 rewriter, loc, tileType, memrefIndices, memrefRank, mask,
Value{},
135 makeLoopBody(index, adjustedIndices, predicate);
170 struct TileLoadOpConversion :
public OpRewritePattern<arm_sme::TileLoadOp> {
173 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
175 auto loc = tileLoadOp.getLoc();
176 auto tileType = tileLoadOp.getVectorType();
177 auto mask = tileLoadOp.getMask();
183 loc,
"unsupported mask op, only 'vector.create_mask' is "
184 "currently supported");
185 auto padOp = tileLoadOp.getPadding();
186 assert(padOp &&
"expected padding when masking!");
188 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
189 if (!constPadOp || constPadOp.getValue() !=
192 tileLoadOp,
"op has non-zero pad, needs non-zero pad pattern");
197 initTile = rewriter.
create<arm_sme::ZeroOp>(loc, tileType);
199 initTile = rewriter.
create<arm_sme::GetTileOp>(loc, tileType);
203 auto forOp = createLoadStoreForOverTileSlices(
204 rewriter, loc, tileType, tileLoadOp.getIndices(),
205 tileLoadOp.getMemRefType().getRank(), mask, initTile,
210 return rewriter.create<arm_sme::LoadTileSliceOp>(
211 loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
212 memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
219 rewriter.
replaceOp(tileLoadOp, forOp->getResult(0));
252 struct TileLoadOpWithMaskAndPadNonZeroConversion
256 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
259 auto loc = tileLoadOp.getLoc();
260 auto tileType = tileLoadOp.getVectorType();
261 auto tileElementType = tileType.getElementType();
263 auto maskOp = tileLoadOp.getMask();
266 tileLoadOp,
"op has no mask, needs unmasked pattern");
268 auto padOp = tileLoadOp.getPadding();
269 assert(padOp &&
"expected padding when masking!");
271 auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
274 tileLoadOp,
"unsupported mask op, only 'vector.create_mask' is "
275 "currently supported");
277 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
279 constPadOp.getValue() == rewriter.
getZeroAttr(tileElementType))
281 tileLoadOp,
"op has constant zero pad, needs zero pad pattern");
283 auto numRows = createMaskOp.getOperands()[0];
284 auto numCols = createMaskOp.getOperands()[1];
286 auto numColsI32 = rewriter.
create<arith::IndexCastUIOp>(
289 auto initTile = rewriter.
create<arm_sme::GetTileOp>(loc, tileType);
292 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
293 auto minTileSlices = rewriter.
create<arith::ConstantIndexOp>(
297 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
299 rewriter.
create<arith::MulIOp>(loc, minTileSlices, vscale);
300 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, numTileSlices,
305 auto tileSliceIndex = forOp.getInductionVar();
306 auto currentTile = forOp.getRegionIterArg(0);
309 auto rowIsActive = rewriter.
create<arith::CmpIOp>(
310 loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
311 auto rowIsActiveI32 = rewriter.
create<arith::ExtSIOp>(
313 auto mask = rewriter.
create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
318 auto maskOp1D = rewriter.
create<vector::CreateMaskOp>(
319 loc, predicateType, maskIndex.getResult());
321 auto memrefIndices = getMemrefIndices(
322 tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
323 tileSliceIndex, numTileSlices, loc, rewriter);
327 auto pad1DOp = rewriter.
create<vector::SplatOp>(loc, tileSliceType, padOp);
329 auto loadSlice = rewriter.
create<vector::MaskedLoadOp>(
330 loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
334 auto insertSlice = rewriter.
create<arm_sme::InsertTileSliceOp>(
335 loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
336 tileLoadOp.getLayout());
337 rewriter.
create<scf::YieldOp>(loc, insertSlice.getResult());
342 rewriter.
replaceOp(tileLoadOp, forOp.getResult(0));
369 struct TileStoreOpConversion :
public OpRewritePattern<arm_sme::TileStoreOp> {
372 LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
374 if (
Value mask = tileStoreOp.getMask()) {
375 if (!mask.getDefiningOp<vector::CreateMaskOp>())
377 tileStoreOp.getLoc(),
378 "unsupported mask op, only 'vector.create_mask' is "
379 "currently supported");
383 return createLoadStoreForOverTileSlices(
384 rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
385 tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
386 tileStoreOp.getMask(),
388 rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
389 tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
390 predicate, tileStoreOp.getBase(), memrefIndices,
391 tileStoreOp.getLayout());
399 patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
400 TileStoreOpConversion>(
patterns.getContext());
405 struct ConvertArmSMEToSCFPass
406 :
public impl::ConvertArmSMEToSCFPassBase<ConvertArmSMEToSCFPass> {
407 void runOnOperation()
override {
411 target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
412 arith::ArithDialect, scf::SCFDialect>();
413 target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
static MLIRContext * getContext(OpFoldResult val)
TypedAttr getZeroAttr(Type type)
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
operand_range getOperands()
Returns an iterator range over the underlying Values.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase the dimension at position `pos` from the shape.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Include the generated interface declarations.
void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the ArmSME dialect to SCF.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.