22 #define GEN_PASS_DEF_CONVERTARMSMETOSCF
23 #include "mlir/Conversion/Passes.h.inc"
36 assert((rank == 1 || rank == 2) &&
"memref has unexpected rank!");
39 auto tileSliceOffset = tileSliceIndex;
42 rewriter.
create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
44 auto baseIndexPlusTileSliceOffset =
45 rewriter.
create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
46 outIndices.push_back(baseIndexPlusTileSliceOffset);
49 outIndices.push_back(indices[1]);
61 PatternRewriter::InsertionGuard guard(rewriter);
63 auto minTileSlices = rewriter.
create<arith::ConstantIndexOp>(
74 rewriter.
create<arith::MulIOp>(loc, minTileSlices, vscale);
79 auto createMaskOp = mask.
getDefiningOp<vector::CreateMaskOp>();
82 loc,
"unsupported mask op, only 'vector.create_mask' is "
83 "currently supported");
85 auto maskDim0 = createMaskOp.getOperands()[0];
86 auto maskDim1 = createMaskOp.getOperands()[1];
91 auto numRowI64 = rewriter.
create<arith::IndexCastOp>(
93 auto numTileSlicesI64 = rewriter.
create<arith::IndexCastOp>(
96 rewriter.
create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
97 upperBound = rewriter.
create<arith::IndexCastOp>(
101 rewriter.
create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
103 upperBound = numTileSlices;
105 predicate = rewriter.
create<arith::ConstantOp>(
109 bool hasCarriedArgs = bool(initTile);
110 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
111 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
112 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step,
117 Value tileSliceIndex = forOp.getInductionVar();
119 auto adjustedIndices = getMemrefIndices(
120 memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
121 auto nextTile = makeLoopBody(
122 tileSliceIndex, adjustedIndices, predicate,
123 hasCarriedArgs ? forOp.getRegionIterArg(0) :
Value{});
125 assert(
bool(nextTile) == hasCarriedArgs);
127 rewriter.
create<scf::YieldOp>(loc, nextTile);
137 return createLoadStoreForOverTileSlices(
138 rewriter, loc, tileType, memrefIndices, memrefRank, mask,
Value{},
141 makeLoopBody(index, adjustedIndices, predicate);
176 struct TileLoadOpConversion :
public OpRewritePattern<arm_sme::TileLoadOp> {
179 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
181 auto loc = tileLoadOp.getLoc();
182 auto tileType = tileLoadOp.getVectorType();
183 auto mask = tileLoadOp.getMask();
187 auto padOp = tileLoadOp.getPadding();
188 assert(padOp &&
"expected padding when masking!");
190 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
191 if (!constPadOp || constPadOp.getValue() !=
194 tileLoadOp,
"op has non-zero pad, needs non-zero pad pattern");
199 initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
200 rewriter, loc, tileType);
203 initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
204 rewriter, loc, tileType);
208 auto forOp = createLoadStoreForOverTileSlices(
209 rewriter, loc, tileType, tileLoadOp.getIndices(),
210 tileLoadOp.getMemRefType().getRank(), mask, initTile,
215 return tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
216 rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
217 currentTile, memrefIndices, tileSliceIndex,
218 tileLoadOp.getLayout());
225 rewriter.
replaceOp(tileLoadOp, forOp->getResult(0));
258 struct TileLoadOpWithMaskAndPadNonZeroConversion
262 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
265 auto loc = tileLoadOp.getLoc();
266 auto tileType = tileLoadOp.getVectorType();
267 auto tileElementType = tileType.getElementType();
269 auto maskOp = tileLoadOp.getMask();
272 tileLoadOp,
"op has no mask, needs unmasked pattern");
274 auto padOp = tileLoadOp.getPadding();
275 assert(padOp &&
"expected padding when masking!");
277 auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
280 tileLoadOp,
"unsupported mask op, only 'vector.create_mask' is "
281 "currently supported");
283 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
285 constPadOp.getValue() == rewriter.
getZeroAttr(tileElementType))
287 tileLoadOp,
"op has constant zero pad, needs zero pad pattern");
289 auto numRows = createMaskOp.getOperands()[0];
290 auto numCols = createMaskOp.getOperands()[1];
292 auto numColsI32 = rewriter.
create<arith::IndexCastUIOp>(
296 auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
297 rewriter, loc, tileType);
300 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
301 auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
304 rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
305 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
307 rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
308 auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
311 rewriter.setInsertionPointToStart(forOp.getBody());
313 auto tileSliceIndex = forOp.getInductionVar();
314 auto currentTile = forOp.getRegionIterArg(0);
317 auto rowIsActive = rewriter.create<arith::CmpIOp>(
318 loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
319 auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
320 loc, rewriter.getI32Type(), rowIsActive);
321 auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
323 rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
326 auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
327 loc, predicateType, maskIndex.getResult());
329 auto memrefIndices = getMemrefIndices(
330 tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
331 tileSliceIndex, numTileSlices, loc, rewriter);
335 auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
337 auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
338 loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
343 tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
344 rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
345 tileSliceIndex, tileLoadOp.getLayout());
346 rewriter.create<scf::YieldOp>(loc, moveSlice.getResult());
348 rewriter.setInsertionPointAfter(forOp);
351 rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
378 struct TileStoreOpConversion :
public OpRewritePattern<arm_sme::TileStoreOp> {
381 LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
384 return createLoadStoreForOverTileSlices(
385 rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
386 tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
387 tileStoreOp.getMask(),
389 tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
390 rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
391 predicate, tileStoreOp.getBase(), memrefIndices,
392 tileStoreOp.getLayout());
400 patterns.
add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
401 TileStoreOpConversion>(patterns.
getContext());
406 struct ConvertArmSMEToSCFPass
407 :
public impl::ConvertArmSMEToSCFBase<ConvertArmSMEToSCFPass> {
408 void runOnOperation()
override {
412 target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
413 arith::ArithDialect, scf::SCFDialect>();
414 target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
416 std::move(patterns))))
424 return std::make_unique<ConvertArmSMEToSCFPass>();
static MLIRContext * getContext(OpFoldResult val)
TypedAttr getZeroAttr(Type type)
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase the dimension at position `pos` from the shape.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToSCFPass()
Create a pass to convert a subset of ArmSME ops to SCF.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the ArmSME dialect to SCF.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.