MLIR 22.0.0git
ArmSMEToSCF.cpp
Go to the documentation of this file.
1//===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of ArmSME operations to SCF.
10//
11//===----------------------------------------------------------------------===//
13
18#include "mlir/Pass/Pass.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTARMSMETOSCFPASS
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27
28namespace {
29/// Returns adjusted (1-D or 2-D) `indices` for a tile slice as follows:
30/// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
31/// rank 2: (indices[0] + tileSliceIndex, indices[1])
32SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
33 Value tileSliceIndex,
34 Value tileSliceNumElts, Location loc,
35 PatternRewriter &rewriter) {
36 assert(rank == 2 && "memref has unexpected rank!");
37 SmallVector<Value, 2> outIndices;
38
39 auto tileSliceOffset = tileSliceIndex;
40
41 auto baseIndexPlusTileSliceOffset =
42 arith::AddIOp::create(rewriter, loc, indices[0], tileSliceOffset);
43 outIndices.push_back(baseIndexPlusTileSliceOffset);
44 outIndices.push_back(indices[1]);
45
46 return outIndices;
47}
48
49/// Creates an scf.for for the load/store of an ArmSME tile.
50FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
51 PatternRewriter &rewriter, Location loc, VectorType tileType,
52 ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
53 function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
54 /*currentTile=*/Value)>
55 makeLoopBody) {
56 PatternRewriter::InsertionGuard guard(rewriter);
57
58 // TODO: This case should be captured and rejected by a verifier.
59 if (memrefIndices.size() != 2)
60 return rewriter.notifyMatchFailure(loc, "invalid number of indices");
61
62 auto minTileSlices = arith::ConstantIndexOp::create(
63 rewriter, loc,
64 arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
65 auto vscale =
66 vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
67 auto predicateType =
68 VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
69
70 // This describes both the number of ZA tile slices and the number of
71 // elements in a vector of SVL bits for a given element type (SVL_B,
72 // SVL_H, ..., SVL_Q).
73 auto numTileSlices =
74 arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
75
76 Value predicate;
77 Value upperBound;
78 if (mask) {
79 auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
80 auto maskDim0 = createMaskOp.getOperands()[0];
81 auto maskDim1 = createMaskOp.getOperands()[1];
82
83 // The upper bound of the loop must be clamped at `numTileSlices` as
84 // `vector.create_mask` allows operands to be greater than the size of a
85 // dimension.
86 auto numRowI64 = arith::IndexCastOp::create(
87 rewriter, loc, rewriter.getI64Type(), maskDim0);
88 auto numTileSlicesI64 = arith::IndexCastOp::create(
89 rewriter, loc, rewriter.getI64Type(), numTileSlices);
90 auto upperBoundI64 =
91 arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64);
92 upperBound = arith::IndexCastOp::create(
93 rewriter, loc, rewriter.getIndexType(), upperBoundI64);
94
95 predicate =
96 vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1);
97 } else {
98 upperBound = numTileSlices;
99 // No mask. Create an 'all true' predicate for the tile slice.
100 predicate = arith::ConstantOp::create(
101 rewriter, loc, DenseElementsAttr::get(predicateType, true));
102 }
103
104 bool hasCarriedArgs = bool(initTile);
105 auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
106 auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
107 auto forOp =
108 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
109 hasCarriedArgs ? ValueRange{initTile} : ValueRange{});
110
111 rewriter.setInsertionPointToStart(forOp.getBody());
112 Value tileSliceIndex = forOp.getInductionVar();
113
114 auto adjustedIndices = getMemrefIndices(
115 memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
116 auto nextTile = makeLoopBody(
117 tileSliceIndex, adjustedIndices, predicate,
118 /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(0) : Value{});
119
120 assert(bool(nextTile) == hasCarriedArgs);
121 if (nextTile)
122 scf::YieldOp::create(rewriter, loc, nextTile);
123
124 return forOp;
125}
126
127FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
128 PatternRewriter &rewriter, Location loc, VectorType tileType,
129 ValueRange memrefIndices, int memrefRank, Value mask,
130 function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
131 makeLoopBody) {
132 return createLoadStoreForOverTileSlices(
133 rewriter, loc, tileType, memrefIndices, memrefRank, mask, Value{},
134 [&](Value index, ValueRange adjustedIndices, Value predicate,
135 Value) -> Value {
136 makeLoopBody(index, adjustedIndices, predicate);
137 return {};
138 });
139}
140
141/// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
142///
143/// With a mask:
144///
145/// BEFORE:
146/// ```mlir
147/// %pad = arith.constant 0 : i32
148/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
149/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
150/// memref<?x?xi32>, vector<[4]x[4]xi32>
151/// ```
152///
153/// AFTER:
154/// ```mlir
155/// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
156/// %mask_cols = vector.create_mask %num_cols : vector<[4]xi1>
157/// %loop_rows = arith.minsi %num_rows, %svl_s : index
158/// %tile = scf.for %tile_slice_idx = %c0 to %loop_rows step %c1
159/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
160/// %tile_update = arm_sme.load_tile_slice
161/// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
162/// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
163/// scf.yield %tile_update : vector<[4]x[4]xi32>
164/// }
165/// ```
166///
167/// Without a mask the lowering is pretty much identical. The only difference is
168/// %mask_cols becomes an all-true mask, and %loop_rows becomes %svl_s.
169///
170/// NOTE: Only mask of 'vector.create_mask' op is currently supported.
171struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
172 using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
173
174 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
175 PatternRewriter &rewriter) const override {
176 auto loc = tileLoadOp.getLoc();
177 auto tileType = tileLoadOp.getVectorType();
178 auto mask = tileLoadOp.getMask();
179
180 Value initTile;
181 if (mask) {
182 if (!mask.getDefiningOp<vector::CreateMaskOp>())
183 return rewriter.notifyMatchFailure(
184 loc, "unsupported mask op, only 'vector.create_mask' is "
185 "currently supported");
186 auto padOp = tileLoadOp.getPadding();
187 assert(padOp && "expected padding when masking!");
188
189 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
190 if (!constPadOp || constPadOp.getValue() !=
191 rewriter.getZeroAttr(tileType.getElementType()))
192 return rewriter.notifyMatchFailure(
193 tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
194
195 // Initialize tile with zero to satisfy padding. Inactive cols will be
196 // zeroed anyway since the loads use zeroing predication. For inactive
197 // rows however, no load will occur so these need to be zeroed.
198 initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType);
199 } else {
200 initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
201 }
202
203 // Create a loop to load the active tile slices from memory.
204 auto forOp = createLoadStoreForOverTileSlices(
205 rewriter, loc, tileType, tileLoadOp.getIndices(),
206 tileLoadOp.getMemRefType().getRank(), mask, initTile,
207 [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
208 Value currentTile) -> Value {
209 // Create 'arm_sme.load_tile_slice' to load tile slice from memory
210 // into tile.
211 return arm_sme::LoadTileSliceOp::create(
212 rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
213 currentTile, memrefIndices, tileSliceIndex,
214 tileLoadOp.getLayout());
215 });
216
217 if (failed(forOp))
218 return forOp;
219
220 // Replace 'arm_sme.tile_load' with the result.
221 rewriter.replaceOp(tileLoadOp, forOp->getResult(0));
222
223 return success();
224 }
225};
226
227/// Lower `arm_sme.tile_load` with mask and non-zero pad.
228///
229/// BEFORE:
230/// ```mlir
231/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
232/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
233/// memref<?x?xi32>, vector<[4]x[4]xi32>
234/// ```
235///
236/// AFTER:
237/// ```mlir
238/// ...
239/// %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32>
240/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
241/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
242/// ...
243/// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
244/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
245/// : memref<?x?xi32>, vector<[4]xi1>,
246/// vector<[4]xi32> into vector<[4]xi32>
247/// // Insert slice into tile
248/// %tile_update = arm_sme.insert_tile_slice
249/// %slice, %iter_tile[%tile_slice_idx] :
250/// vector<[4]xi32> into vector<[4]x[4]xi32>
251/// scf.yield %tile_update : vector<[4]x[4]xi32>
252/// }
253/// ```
254struct TileLoadOpWithMaskAndPadNonZeroConversion
255 : public OpRewritePattern<arm_sme::TileLoadOp> {
256 using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
257
258 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
259 PatternRewriter &rewriter) const override {
260 OpBuilder::InsertionGuard g(rewriter);
261 auto loc = tileLoadOp.getLoc();
262 auto tileType = tileLoadOp.getVectorType();
263 auto tileElementType = tileType.getElementType();
264
265 auto maskOp = tileLoadOp.getMask();
266 if (!maskOp)
267 return rewriter.notifyMatchFailure(
268 tileLoadOp, "op has no mask, needs unmasked pattern");
269
270 auto padOp = tileLoadOp.getPadding();
271 assert(padOp && "expected padding when masking!");
272
273 auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
274 if (!createMaskOp)
275 return rewriter.notifyMatchFailure(
276 tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
277 "currently supported");
278
279 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
280 if (constPadOp &&
281 constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
282 return rewriter.notifyMatchFailure(
283 tileLoadOp, "op has constant zero pad, needs zero pad pattern");
284
285 auto numRows = createMaskOp.getOperands()[0];
286 auto numCols = createMaskOp.getOperands()[1];
287
288 auto numColsI32 = arith::IndexCastUIOp::create(
289 rewriter, loc, rewriter.getI32Type(), numCols);
290
291 auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
292
293 // Create a loop that loads each ZA tile slice from memory.
294 auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
295 auto minTileSlices = arith::ConstantIndexOp::create(
296 rewriter, loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
297 auto vscale =
298 vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
299 auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
300 auto numTileSlices =
301 arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
302 auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
303 step, ValueRange{initTile});
304
305 rewriter.setInsertionPointToStart(forOp.getBody());
306
307 auto tileSliceIndex = forOp.getInductionVar();
308 auto currentTile = forOp.getRegionIterArg(0);
309
310 // Combine masks.
311 auto rowIsActive = arith::CmpIOp::create(
312 rewriter, loc, arith::CmpIPredicate::slt, tileSliceIndex, numRows);
313 auto rowIsActiveI32 = arith::ExtSIOp::create(
314 rewriter, loc, rewriter.getI32Type(), rowIsActive);
315 auto mask =
316 arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32);
317 auto maskIndex = arith::IndexCastOp::create(rewriter, loc,
318 rewriter.getIndexType(), mask);
319 auto predicateType =
320 VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
321 auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType,
322 maskIndex.getResult());
323
324 auto memrefIndices = getMemrefIndices(
325 tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
326 tileSliceIndex, numTileSlices, loc, rewriter);
327
328 // Splat pad into 1-D vector matching type of tile slice.
329 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
330 auto pad1DOp =
331 vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
332
333 auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
334 tileLoadOp.getBase(),
335 memrefIndices, maskOp1D,
336 /*passthrough=*/pad1DOp);
337
338 // Create 'arm_sme.insert_tile_slice' to insert slice into tile.
339 auto insertSlice = arm_sme::InsertTileSliceOp::create(
340 rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
341 tileSliceIndex, tileLoadOp.getLayout());
342 scf::YieldOp::create(rewriter, loc, insertSlice.getResult());
343
344 rewriter.setInsertionPointAfter(forOp);
345
346 // Replace 'arm_sme.tile_load' with the result.
347 rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
348
349 return success();
350 }
351};
352
353/// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
354/// slice using `arm_sme.store_tile_slice`.
355///
356/// BEFORE:
357/// ```mlir
358/// arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
359/// : memref<?x?xi32>, vector<[4]x[4]xi32
360/// ```
361///
362/// AFTER:
363/// ```mlir
364/// %vscale = vector.vscale
365/// %c0 = arith.constant 0 : index
366/// %c1 = arith.constant 1 : index
367/// %min_svl_s = arith.constant 4 : index
368/// %svl_s = arith.muli %min_svl_s, %vscale : index
369/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
370/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
371/// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
372/// }
373/// ```
374struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
375 using OpRewritePattern<arm_sme::TileStoreOp>::OpRewritePattern;
376
377 LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
378 PatternRewriter &rewriter) const override {
379 if (Value mask = tileStoreOp.getMask()) {
380 if (!mask.getDefiningOp<vector::CreateMaskOp>())
381 return rewriter.notifyMatchFailure(
382 tileStoreOp.getLoc(),
383 "unsupported mask op, only 'vector.create_mask' is "
384 "currently supported");
385 }
386
387 // Create a loop that stores each active ZA tile slice from memory.
388 return createLoadStoreForOverTileSlices(
389 rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
390 tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
391 tileStoreOp.getMask(),
392 [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
393 rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
394 tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
395 predicate, tileStoreOp.getBase(), memrefIndices,
396 tileStoreOp.getLayout());
397 });
398 }
399};
400
401} // namespace
402
404 patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
405 TileStoreOpConversion>(patterns.getContext());
406}
407
408namespace {
409
/// Pass that runs a partial dialect conversion in which `arm_sme.tile_load`
/// and `arm_sme.tile_store` are illegal, and signals pass failure if the
/// conversion fails.
410struct ConvertArmSMEToSCFPass
411 : public impl::ConvertArmSMEToSCFPassBase<ConvertArmSMEToSCFPass> {
412 void runOnOperation() override {
 // NOTE(review): `patterns` and `target` are used below, but the lines that
 // declare them are not present here (the listing jumps from 412 to 416).
 // Presumably lost during extraction — restore them from the original file.
 // All ops of the ArmSME, Vector, Arith and SCF dialects are legal by
 // default ...
416 target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
417 arith::ArithDialect, scf::SCFDialect>();
 // ... except the two tile ops, which are marked illegal.
418 target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
419 if (failed(applyPartialConversion(getOperation(), target,
420 std::move(patterns))))
421 signalPassFailure();
422 }
423};
424
425} // namespace
return success()
b getContext())
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getI32Type()
Definition Builders.cpp:63
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
IntegerType getI1Type()
Definition Builders.cpp:53
IndexType getIndexType()
Definition Builders.cpp:51
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition Utils.cpp:32
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the ArmSME dialect to SCF.
const FrozenRewritePatternSet & patterns
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...