22#define GEN_PASS_DEF_CONVERTARMSMETOSCFPASS
23#include "mlir/Conversion/Passes.h.inc"
36 assert(rank == 2 &&
"memref has unexpected rank!");
39 auto tileSliceOffset = tileSliceIndex;
41 auto baseIndexPlusTileSliceOffset =
42 arith::AddIOp::create(rewriter, loc,
indices[0], tileSliceOffset);
43 outIndices.push_back(baseIndexPlusTileSliceOffset);
44 outIndices.push_back(
indices[1]);
50FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
59 if (memrefIndices.size() != 2)
66 vector::VectorScaleOp::create(rewriter, loc, rewriter.
getIndexType());
68 VectorType::get(tileType.getDimSize(1), rewriter.
getI1Type(),
true);
74 arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
79 auto createMaskOp = mask.
getDefiningOp<vector::CreateMaskOp>();
81 auto maskDim1 = createMaskOp.getOperands()[1];
86 auto numRowI64 = arith::IndexCastOp::create(
87 rewriter, loc, rewriter.
getI64Type(), maskDim0);
88 auto numTileSlicesI64 = arith::IndexCastOp::create(
89 rewriter, loc, rewriter.
getI64Type(), numTileSlices);
91 arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64);
92 upperBound = arith::IndexCastOp::create(
96 vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1);
98 upperBound = numTileSlices;
100 predicate = arith::ConstantOp::create(
104 bool hasCarriedArgs =
bool(initTile);
108 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
112 Value tileSliceIndex = forOp.getInductionVar();
114 auto adjustedIndices = getMemrefIndices(
115 memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
116 auto nextTile = makeLoopBody(
117 tileSliceIndex, adjustedIndices, predicate,
118 hasCarriedArgs ? forOp.getRegionIterArg(0) :
Value{});
120 assert(
bool(nextTile) == hasCarriedArgs);
122 scf::YieldOp::create(rewriter, loc, nextTile);
127FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
132 return createLoadStoreForOverTileSlices(
133 rewriter, loc, tileType, memrefIndices, memrefRank, mask,
Value{},
136 makeLoopBody(
index, adjustedIndices, predicate);
172 using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
174 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
175 PatternRewriter &rewriter)
const override {
176 auto loc = tileLoadOp.getLoc();
177 auto tileType = tileLoadOp.getVectorType();
178 auto mask = tileLoadOp.getMask();
184 loc,
"unsupported mask op, only 'vector.create_mask' is "
185 "currently supported");
186 auto padOp = tileLoadOp.getPadding();
187 assert(padOp &&
"expected padding when masking!");
189 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
190 if (!constPadOp || constPadOp.getValue() !=
193 tileLoadOp,
"op has non-zero pad, needs non-zero pad pattern");
198 initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType);
200 initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
204 auto forOp = createLoadStoreForOverTileSlices(
205 rewriter, loc, tileType, tileLoadOp.getIndices(),
206 tileLoadOp.getMemRefType().getRank(), mask, initTile,
207 [&](Value tileSliceIndex,
ValueRange memrefIndices, Value predicate,
208 Value currentTile) -> Value {
211 return arm_sme::LoadTileSliceOp::create(
212 rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
213 currentTile, memrefIndices, tileSliceIndex,
214 tileLoadOp.getLayout());
221 rewriter.
replaceOp(tileLoadOp, forOp->getResult(0));
254struct TileLoadOpWithMaskAndPadNonZeroConversion
256 using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
258 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
259 PatternRewriter &rewriter)
const override {
260 OpBuilder::InsertionGuard g(rewriter);
261 auto loc = tileLoadOp.getLoc();
262 auto tileType = tileLoadOp.getVectorType();
263 auto tileElementType = tileType.getElementType();
265 auto maskOp = tileLoadOp.getMask();
268 tileLoadOp,
"op has no mask, needs unmasked pattern");
270 auto padOp = tileLoadOp.getPadding();
271 assert(padOp &&
"expected padding when masking!");
273 auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
276 tileLoadOp,
"unsupported mask op, only 'vector.create_mask' is "
277 "currently supported");
279 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
281 constPadOp.getValue() == rewriter.
getZeroAttr(tileElementType))
283 tileLoadOp,
"op has constant zero pad, needs zero pad pattern");
285 auto numRows = createMaskOp.getOperands()[0];
286 auto numCols = createMaskOp.getOperands()[1];
288 auto numColsI32 = arith::IndexCastUIOp::create(
289 rewriter, loc, rewriter.
getI32Type(), numCols);
291 auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
298 vector::VectorScaleOp::create(rewriter, loc, rewriter.
getIndexType());
301 arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
302 auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
307 auto tileSliceIndex = forOp.getInductionVar();
308 auto currentTile = forOp.getRegionIterArg(0);
311 auto rowIsActive = arith::CmpIOp::create(
312 rewriter, loc, arith::CmpIPredicate::slt, tileSliceIndex, numRows);
313 auto rowIsActiveI32 = arith::ExtSIOp::create(
314 rewriter, loc, rewriter.
getI32Type(), rowIsActive);
316 arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32);
317 auto maskIndex = arith::IndexCastOp::create(rewriter, loc,
320 VectorType::get(tileType.getDimSize(1), rewriter.
getI1Type(),
true);
321 auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType,
322 maskIndex.getResult());
324 auto memrefIndices = getMemrefIndices(
325 tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
326 tileSliceIndex, numTileSlices, loc, rewriter);
329 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
331 vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
333 auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
334 tileLoadOp.getBase(),
335 memrefIndices, maskOp1D,
339 auto insertSlice = arm_sme::InsertTileSliceOp::create(
340 rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
341 tileSliceIndex, tileLoadOp.getLayout());
342 scf::YieldOp::create(rewriter, loc, insertSlice.getResult());
347 rewriter.
replaceOp(tileLoadOp, forOp.getResult(0));
374struct TileStoreOpConversion :
public OpRewritePattern<arm_sme::TileStoreOp> {
375 using OpRewritePattern<arm_sme::TileStoreOp>::OpRewritePattern;
377 LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
378 PatternRewriter &rewriter)
const override {
379 if (Value mask = tileStoreOp.getMask()) {
382 tileStoreOp.getLoc(),
383 "unsupported mask op, only 'vector.create_mask' is "
384 "currently supported");
388 return createLoadStoreForOverTileSlices(
389 rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
390 tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
391 tileStoreOp.getMask(),
392 [&](Value tileSliceIndex,
ValueRange memrefIndices, Value predicate) {
393 rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
394 tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
395 predicate, tileStoreOp.getBase(), memrefIndices,
396 tileStoreOp.getLayout());
404 patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
405 TileStoreOpConversion>(
patterns.getContext());
410struct ConvertArmSMEToSCFPass
412 void runOnOperation()
override {
416 target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
417 arith::ArithDialect, scf::SCFDialect>();
418 target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
419 if (failed(applyPartialConversion(getOperation(),
target,
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Include the generated interface declarations.
void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the ArmSME dialect to SCF.
const FrozenRewritePatternSet & patterns
llvm::function_ref< Fn > function_ref
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...