22 #define GEN_PASS_DEF_CONVERTARMSMETOSCFPASS
23 #include "mlir/Conversion/Passes.h.inc"
36 assert(rank == 2 &&
"memref has unexpected rank!");
39 auto tileSliceOffset = tileSliceIndex;
41 auto baseIndexPlusTileSliceOffset =
42 arith::AddIOp::create(rewriter, loc, indices[0], tileSliceOffset);
43 outIndices.push_back(baseIndexPlusTileSliceOffset);
44 outIndices.push_back(indices[1]);
50 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
56 PatternRewriter::InsertionGuard guard(rewriter);
59 if (memrefIndices.size() != 2)
66 vector::VectorScaleOp::create(rewriter, loc, rewriter.
getIndexType());
74 arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
79 auto createMaskOp = mask.
getDefiningOp<vector::CreateMaskOp>();
81 auto maskDim1 = createMaskOp.getOperands()[1];
86 auto numRowI64 = arith::IndexCastOp::create(
87 rewriter, loc, rewriter.
getI64Type(), maskDim0);
88 auto numTileSlicesI64 = arith::IndexCastOp::create(
89 rewriter, loc, rewriter.
getI64Type(), numTileSlices);
91 arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64);
92 upperBound = arith::IndexCastOp::create(
96 vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1);
98 upperBound = numTileSlices;
100 predicate = arith::ConstantOp::create(
104 bool hasCarriedArgs = bool(initTile);
108 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
112 Value tileSliceIndex = forOp.getInductionVar();
114 auto adjustedIndices = getMemrefIndices(
115 memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
116 auto nextTile = makeLoopBody(
117 tileSliceIndex, adjustedIndices, predicate,
118 hasCarriedArgs ? forOp.getRegionIterArg(0) :
Value{});
120 assert(
bool(nextTile) == hasCarriedArgs);
122 scf::YieldOp::create(rewriter, loc, nextTile);
127 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
132 return createLoadStoreForOverTileSlices(
133 rewriter, loc, tileType, memrefIndices, memrefRank, mask,
Value{},
136 makeLoopBody(index, adjustedIndices, predicate);
171 struct TileLoadOpConversion :
public OpRewritePattern<arm_sme::TileLoadOp> {
174 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
176 auto loc = tileLoadOp.getLoc();
177 auto tileType = tileLoadOp.getVectorType();
178 auto mask = tileLoadOp.getMask();
184 loc,
"unsupported mask op, only 'vector.create_mask' is "
185 "currently supported");
186 auto padOp = tileLoadOp.getPadding();
187 assert(padOp &&
"expected padding when masking!");
189 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
190 if (!constPadOp || constPadOp.getValue() !=
193 tileLoadOp,
"op has non-zero pad, needs non-zero pad pattern");
198 initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType);
200 initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
204 auto forOp = createLoadStoreForOverTileSlices(
205 rewriter, loc, tileType, tileLoadOp.getIndices(),
206 tileLoadOp.getMemRefType().getRank(), mask, initTile,
211 return arm_sme::LoadTileSliceOp::create(
212 rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
213 currentTile, memrefIndices, tileSliceIndex,
214 tileLoadOp.getLayout());
221 rewriter.
replaceOp(tileLoadOp, forOp->getResult(0));
254 struct TileLoadOpWithMaskAndPadNonZeroConversion
258 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
261 auto loc = tileLoadOp.getLoc();
262 auto tileType = tileLoadOp.getVectorType();
263 auto tileElementType = tileType.getElementType();
265 auto maskOp = tileLoadOp.getMask();
268 tileLoadOp,
"op has no mask, needs unmasked pattern");
270 auto padOp = tileLoadOp.getPadding();
271 assert(padOp &&
"expected padding when masking!");
273 auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
276 tileLoadOp,
"unsupported mask op, only 'vector.create_mask' is "
277 "currently supported");
279 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
281 constPadOp.getValue() == rewriter.
getZeroAttr(tileElementType))
283 tileLoadOp,
"op has constant zero pad, needs zero pad pattern");
285 auto numRows = createMaskOp.getOperands()[0];
286 auto numCols = createMaskOp.getOperands()[1];
288 auto numColsI32 = arith::IndexCastUIOp::create(
289 rewriter, loc, rewriter.
getI32Type(), numCols);
291 auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
298 vector::VectorScaleOp::create(rewriter, loc, rewriter.
getIndexType());
301 arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
302 auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
307 auto tileSliceIndex = forOp.getInductionVar();
308 auto currentTile = forOp.getRegionIterArg(0);
311 auto rowIsActive = arith::CmpIOp::create(
312 rewriter, loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
313 auto rowIsActiveI32 = arith::ExtSIOp::create(
314 rewriter, loc, rewriter.
getI32Type(), rowIsActive);
316 arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32);
317 auto maskIndex = arith::IndexCastOp::create(rewriter, loc,
321 auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType,
322 maskIndex.getResult());
324 auto memrefIndices = getMemrefIndices(
325 tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
326 tileSliceIndex, numTileSlices, loc, rewriter);
331 vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
333 auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
334 tileLoadOp.getBase(),
335 memrefIndices, maskOp1D,
339 auto insertSlice = arm_sme::InsertTileSliceOp::create(
340 rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
341 tileSliceIndex, tileLoadOp.getLayout());
342 scf::YieldOp::create(rewriter, loc, insertSlice.getResult());
347 rewriter.
replaceOp(tileLoadOp, forOp.getResult(0));
374 struct TileStoreOpConversion :
public OpRewritePattern<arm_sme::TileStoreOp> {
377 LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
379 if (
Value mask = tileStoreOp.getMask()) {
382 tileStoreOp.getLoc(),
383 "unsupported mask op, only 'vector.create_mask' is "
384 "currently supported");
388 return createLoadStoreForOverTileSlices(
389 rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
390 tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
391 tileStoreOp.getMask(),
393 rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
394 tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
395 predicate, tileStoreOp.getBase(), memrefIndices,
396 tileStoreOp.getLayout());
404 patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
405 TileStoreOpConversion>(
patterns.getContext());
410 struct ConvertArmSMEToSCFPass
411 :
public impl::ConvertArmSMEToSCFPassBase<ConvertArmSMEToSCFPass> {
412 void runOnOperation()
override {
416 target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
417 arith::ArithDialect, scf::SCFDialect>();
418 target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
static MLIRContext * getContext(OpFoldResult val)
TypedAttr getZeroAttr(Type type)
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The result types of the given op and the replacements must match. The original op is erased.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase the dimension at position `pos` from the shape.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Include the generated interface declarations.
void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the ArmSME dialect to SCF.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.