22 #define GEN_PASS_DEF_CONVERTARMSMETOSCF
23 #include "mlir/Conversion/Passes.h.inc"
36 assert((rank == 1 || rank == 2) &&
"memref has unexpected rank!");
39 auto tileSliceOffset = tileSliceIndex;
42 rewriter.
create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
44 auto baseIndexPlusTileSliceOffset =
45 rewriter.
create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
46 outIndices.push_back(baseIndexPlusTileSliceOffset);
49 outIndices.push_back(indices[1]);
55 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
61 PatternRewriter::InsertionGuard guard(rewriter);
63 auto minTileSlices = rewriter.
create<arith::ConstantIndexOp>(
74 rewriter.
create<arith::MulIOp>(loc, minTileSlices, vscale);
79 auto createMaskOp = mask.
getDefiningOp<vector::CreateMaskOp>();
82 loc,
"unsupported mask op, only 'vector.create_mask' is "
83 "currently supported");
85 auto maskDim0 = createMaskOp.getOperands()[0];
86 auto maskDim1 = createMaskOp.getOperands()[1];
91 auto numRowI64 = rewriter.
create<arith::IndexCastOp>(
93 auto numTileSlicesI64 = rewriter.
create<arith::IndexCastOp>(
96 rewriter.
create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
97 upperBound = rewriter.
create<arith::IndexCastOp>(
101 rewriter.
create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
103 upperBound = numTileSlices;
105 predicate = rewriter.
create<arith::ConstantOp>(
109 bool hasCarriedArgs = bool(initTile);
110 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
111 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
112 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step,
117 Value tileSliceIndex = forOp.getInductionVar();
119 auto adjustedIndices = getMemrefIndices(
120 memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
121 auto nextTile = makeLoopBody(
122 tileSliceIndex, adjustedIndices, predicate,
123 hasCarriedArgs ? forOp.getRegionIterArg(0) :
Value{});
125 assert(
bool(nextTile) == hasCarriedArgs);
127 rewriter.
create<scf::YieldOp>(loc, nextTile);
132 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
137 return createLoadStoreForOverTileSlices(
138 rewriter, loc, tileType, memrefIndices, memrefRank, mask,
Value{},
141 makeLoopBody(index, adjustedIndices, predicate);
176 struct TileLoadOpConversion :
public OpRewritePattern<arm_sme::TileLoadOp> {
179 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
181 auto loc = tileLoadOp.getLoc();
182 auto tileType = tileLoadOp.getVectorType();
183 auto mask = tileLoadOp.getMask();
187 auto padOp = tileLoadOp.getPadding();
188 assert(padOp &&
"expected padding when masking!");
190 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
191 if (!constPadOp || constPadOp.getValue() !=
194 tileLoadOp,
"op has non-zero pad, needs non-zero pad pattern");
199 initTile = rewriter.
create<arm_sme::ZeroOp>(loc, tileType);
201 initTile = rewriter.
create<arm_sme::GetTileOp>(loc, tileType);
205 auto forOp = createLoadStoreForOverTileSlices(
206 rewriter, loc, tileType, tileLoadOp.getIndices(),
207 tileLoadOp.getMemRefType().getRank(), mask, initTile,
212 return rewriter.create<arm_sme::LoadTileSliceOp>(
213 loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
214 memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
221 rewriter.
replaceOp(tileLoadOp, forOp->getResult(0));
254 struct TileLoadOpWithMaskAndPadNonZeroConversion
258 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
261 auto loc = tileLoadOp.getLoc();
262 auto tileType = tileLoadOp.getVectorType();
263 auto tileElementType = tileType.getElementType();
265 auto maskOp = tileLoadOp.getMask();
268 tileLoadOp,
"op has no mask, needs unmasked pattern");
270 auto padOp = tileLoadOp.getPadding();
271 assert(padOp &&
"expected padding when masking!");
273 auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
276 tileLoadOp,
"unsupported mask op, only 'vector.create_mask' is "
277 "currently supported");
279 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
281 constPadOp.getValue() == rewriter.
getZeroAttr(tileElementType))
283 tileLoadOp,
"op has constant zero pad, needs zero pad pattern");
285 auto numRows = createMaskOp.getOperands()[0];
286 auto numCols = createMaskOp.getOperands()[1];
288 auto numColsI32 = rewriter.
create<arith::IndexCastUIOp>(
291 auto initTile = rewriter.
create<arm_sme::GetTileOp>(loc, tileType);
294 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
295 auto minTileSlices = rewriter.
create<arith::ConstantIndexOp>(
299 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
301 rewriter.
create<arith::MulIOp>(loc, minTileSlices, vscale);
302 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, numTileSlices,
307 auto tileSliceIndex = forOp.getInductionVar();
308 auto currentTile = forOp.getRegionIterArg(0);
311 auto rowIsActive = rewriter.
create<arith::CmpIOp>(
312 loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
313 auto rowIsActiveI32 = rewriter.
create<arith::ExtSIOp>(
315 auto mask = rewriter.
create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
320 auto maskOp1D = rewriter.
create<vector::CreateMaskOp>(
321 loc, predicateType, maskIndex.getResult());
323 auto memrefIndices = getMemrefIndices(
324 tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
325 tileSliceIndex, numTileSlices, loc, rewriter);
329 auto pad1DOp = rewriter.
create<vector::SplatOp>(loc, tileSliceType, padOp);
331 auto loadSlice = rewriter.
create<vector::MaskedLoadOp>(
332 loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
336 auto insertSlice = rewriter.
create<arm_sme::InsertTileSliceOp>(
337 loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
338 tileLoadOp.getLayout());
339 rewriter.
create<scf::YieldOp>(loc, insertSlice.getResult());
344 rewriter.
replaceOp(tileLoadOp, forOp.getResult(0));
371 struct TileStoreOpConversion :
public OpRewritePattern<arm_sme::TileStoreOp> {
374 LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
377 return createLoadStoreForOverTileSlices(
378 rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
379 tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
380 tileStoreOp.getMask(),
382 rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
383 tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
384 predicate, tileStoreOp.getBase(), memrefIndices,
385 tileStoreOp.getLayout());
393 patterns.
add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
394 TileStoreOpConversion>(patterns.
getContext());
399 struct ConvertArmSMEToSCFPass
400 :
public impl::ConvertArmSMEToSCFBase<ConvertArmSMEToSCFPass> {
401 void runOnOperation()
override {
405 target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
406 arith::ArithDialect, scf::SCFDialect>();
407 target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
409 std::move(patterns))))
417 return std::make_unique<ConvertArmSMEToSCFPass>();
static MLIRContext * getContext(OpFoldResult val)
TypedAttr getZeroAttr(Type type)
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToSCFPass()
Create a pass to convert a subset of ArmSME ops to SCF.
void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the ArmSME dialect to SCF.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...