MLIR  22.0.0git
ArmSMEToSCF.cpp
Go to the documentation of this file.
1 //===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of ArmSME operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
13 
18 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTARMSMETOSCFPASS
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 /// Returns adjusted (1-D or 2-D) `indices` for a tile slice as follows:
30 /// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
31 /// rank 2: (indices[0] + tileSliceIndex, indices[1])
32 SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
33  Value tileSliceIndex,
34  Value tileSliceNumElts, Location loc,
35  PatternRewriter &rewriter) {
36  assert(rank == 2 && "memref has unexpected rank!");
37  SmallVector<Value, 2> outIndices;
38 
39  auto tileSliceOffset = tileSliceIndex;
40 
41  auto baseIndexPlusTileSliceOffset =
42  arith::AddIOp::create(rewriter, loc, indices[0], tileSliceOffset);
43  outIndices.push_back(baseIndexPlusTileSliceOffset);
44  outIndices.push_back(indices[1]);
45 
46  return outIndices;
47 }
48 
49 /// Creates an scf.for for the load/store of an ArmSME tile.
50 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
51  PatternRewriter &rewriter, Location loc, VectorType tileType,
52  ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
53  function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
54  /*currentTile=*/Value)>
55  makeLoopBody) {
56  PatternRewriter::InsertionGuard guard(rewriter);
57 
58  // TODO: This case should be captured and rejected by a verifier.
59  if (memrefIndices.size() != 2)
60  return rewriter.notifyMatchFailure(loc, "invalid number of indices");
61 
62  auto minTileSlices = arith::ConstantIndexOp::create(
63  rewriter, loc,
64  arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
65  auto vscale =
66  vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
67  auto predicateType =
68  VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
69 
70  // This describes both the number of ZA tile slices and the number of
71  // elements in a vector of SVL bits for a given element type (SVL_B,
72  // SVL_H, ..., SVL_Q).
73  auto numTileSlices =
74  arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
75 
76  Value predicate;
77  Value upperBound;
78  if (mask) {
79  auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
80  auto maskDim0 = createMaskOp.getOperands()[0];
81  auto maskDim1 = createMaskOp.getOperands()[1];
82 
83  // The upper bound of the loop must be clamped at `numTileSlices` as
84  // `vector.create_mask` allows operands to be greater than the size of a
85  // dimension.
86  auto numRowI64 = arith::IndexCastOp::create(
87  rewriter, loc, rewriter.getI64Type(), maskDim0);
88  auto numTileSlicesI64 = arith::IndexCastOp::create(
89  rewriter, loc, rewriter.getI64Type(), numTileSlices);
90  auto upperBoundI64 =
91  arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64);
92  upperBound = arith::IndexCastOp::create(
93  rewriter, loc, rewriter.getIndexType(), upperBoundI64);
94 
95  predicate =
96  vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1);
97  } else {
98  upperBound = numTileSlices;
99  // No mask. Create an 'all true' predicate for the tile slice.
100  predicate = arith::ConstantOp::create(
101  rewriter, loc, DenseElementsAttr::get(predicateType, true));
102  }
103 
104  bool hasCarriedArgs = bool(initTile);
105  auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
106  auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
107  auto forOp =
108  scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
109  hasCarriedArgs ? ValueRange{initTile} : ValueRange{});
110 
111  rewriter.setInsertionPointToStart(forOp.getBody());
112  Value tileSliceIndex = forOp.getInductionVar();
113 
114  auto adjustedIndices = getMemrefIndices(
115  memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
116  auto nextTile = makeLoopBody(
117  tileSliceIndex, adjustedIndices, predicate,
118  /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(0) : Value{});
119 
120  assert(bool(nextTile) == hasCarriedArgs);
121  if (nextTile)
122  scf::YieldOp::create(rewriter, loc, nextTile);
123 
124  return forOp;
125 }
126 
127 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
128  PatternRewriter &rewriter, Location loc, VectorType tileType,
129  ValueRange memrefIndices, int memrefRank, Value mask,
130  function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
131  makeLoopBody) {
132  return createLoadStoreForOverTileSlices(
133  rewriter, loc, tileType, memrefIndices, memrefRank, mask, Value{},
134  [&](Value index, ValueRange adjustedIndices, Value predicate,
135  Value) -> Value {
136  makeLoopBody(index, adjustedIndices, predicate);
137  return {};
138  });
139 }
140 
141 /// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
142 ///
143 /// With a mask:
144 ///
145 /// BEFORE:
146 /// ```mlir
147 /// %pad = arith.constant 0 : i32
148 /// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
149 /// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
150 /// memref<?x?xi32>, vector<[4]x[4]xi32>
151 /// ```
152 ///
153 /// AFTER:
154 /// ```mlir
155 /// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
156 /// %mask_cols = vector.create_mask %num_cols : vector<[4]xi1>
157 /// %loop_rows = arith.minsi %num_rows, %svl_s : index
158 /// %tile = scf.for %tile_slice_idx = %c0 to %loop_rows step %c1
159 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
160 /// %tile_update = arm_sme.load_tile_slice
161 /// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
162 /// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
163 /// scf.yield %tile_update : vector<[4]x[4]xi32>
164 /// }
165 /// ```
166 ///
167 /// Without a mask the lowering is pretty much identical. The only difference is
168 /// %mask_cols becomes an all-true mask, and %loop_rows becomes %svl_s.
169 ///
170 /// NOTE: Only mask of 'vector.create_mask' op is currently supported.
171 struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
173 
174  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
175  PatternRewriter &rewriter) const override {
176  auto loc = tileLoadOp.getLoc();
177  auto tileType = tileLoadOp.getVectorType();
178  auto mask = tileLoadOp.getMask();
179 
180  Value initTile;
181  if (mask) {
182  if (!mask.getDefiningOp<vector::CreateMaskOp>())
183  return rewriter.notifyMatchFailure(
184  loc, "unsupported mask op, only 'vector.create_mask' is "
185  "currently supported");
186  auto padOp = tileLoadOp.getPadding();
187  assert(padOp && "expected padding when masking!");
188 
189  auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
190  if (!constPadOp || constPadOp.getValue() !=
191  rewriter.getZeroAttr(tileType.getElementType()))
192  return rewriter.notifyMatchFailure(
193  tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
194 
195  // Initialize tile with zero to satisfy padding. Inactive cols will be
196  // zeroed anyway since the loads use zeroing predication. For inactive
197  // rows however, no load will occur so these need to be zeroed.
198  initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType);
199  } else {
200  initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
201  }
202 
203  // Create a loop to load the active tile slices from memory.
204  auto forOp = createLoadStoreForOverTileSlices(
205  rewriter, loc, tileType, tileLoadOp.getIndices(),
206  tileLoadOp.getMemRefType().getRank(), mask, initTile,
207  [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
208  Value currentTile) -> Value {
209  // Create 'arm_sme.load_tile_slice' to load tile slice from memory
210  // into tile.
211  return arm_sme::LoadTileSliceOp::create(
212  rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
213  currentTile, memrefIndices, tileSliceIndex,
214  tileLoadOp.getLayout());
215  });
216 
217  if (failed(forOp))
218  return forOp;
219 
220  // Replace 'arm_sme.tile_load' with the result.
221  rewriter.replaceOp(tileLoadOp, forOp->getResult(0));
222 
223  return success();
224  }
225 };
226 
227 /// Lower `arm_sme.tile_load` with mask and non-zero pad.
228 ///
229 /// BEFORE:
230 /// ```mlir
231 /// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
232 /// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
233 /// memref<?x?xi32>, vector<[4]x[4]xi32>
234 /// ```
235 ///
236 /// AFTER:
237 /// ```mlir
238 /// ...
239 /// %pad_1d = vector.splat %pad : vector<[4]xi32>
240 /// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
241 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
242 /// ...
243 /// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
244 /// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
245 /// : memref<?x?xi32>, vector<[4]xi1>,
246 /// vector<[4]xi32> into vector<[4]xi32>
247 /// // Insert slice into tile
248 /// %tile_update = arm_sme.insert_tile_slice
249 /// %slice, %iter_tile[%tile_slice_idx] :
250 /// vector<[4]xi32> into vector<[4]x[4]xi32>
251 /// scf.yield %tile_update : vector<[4]x[4]xi32>
252 /// }
253 /// ```
254 struct TileLoadOpWithMaskAndPadNonZeroConversion
255  : public OpRewritePattern<arm_sme::TileLoadOp> {
257 
258  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
259  PatternRewriter &rewriter) const override {
260  OpBuilder::InsertionGuard g(rewriter);
261  auto loc = tileLoadOp.getLoc();
262  auto tileType = tileLoadOp.getVectorType();
263  auto tileElementType = tileType.getElementType();
264 
265  auto maskOp = tileLoadOp.getMask();
266  if (!maskOp)
267  return rewriter.notifyMatchFailure(
268  tileLoadOp, "op has no mask, needs unmasked pattern");
269 
270  auto padOp = tileLoadOp.getPadding();
271  assert(padOp && "expected padding when masking!");
272 
273  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
274  if (!createMaskOp)
275  return rewriter.notifyMatchFailure(
276  tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
277  "currently supported");
278 
279  auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
280  if (constPadOp &&
281  constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
282  return rewriter.notifyMatchFailure(
283  tileLoadOp, "op has constant zero pad, needs zero pad pattern");
284 
285  auto numRows = createMaskOp.getOperands()[0];
286  auto numCols = createMaskOp.getOperands()[1];
287 
288  auto numColsI32 = arith::IndexCastUIOp::create(
289  rewriter, loc, rewriter.getI32Type(), numCols);
290 
291  auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
292 
293  // Create a loop that loads each ZA tile slice from memory.
294  auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
295  auto minTileSlices = arith::ConstantIndexOp::create(
296  rewriter, loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
297  auto vscale =
298  vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
299  auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
300  auto numTileSlices =
301  arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
302  auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
303  step, ValueRange{initTile});
304 
305  rewriter.setInsertionPointToStart(forOp.getBody());
306 
307  auto tileSliceIndex = forOp.getInductionVar();
308  auto currentTile = forOp.getRegionIterArg(0);
309 
310  // Combine masks.
311  auto rowIsActive = arith::CmpIOp::create(
312  rewriter, loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
313  auto rowIsActiveI32 = arith::ExtSIOp::create(
314  rewriter, loc, rewriter.getI32Type(), rowIsActive);
315  auto mask =
316  arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32);
317  auto maskIndex = arith::IndexCastOp::create(rewriter, loc,
318  rewriter.getIndexType(), mask);
319  auto predicateType =
320  VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
321  auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType,
322  maskIndex.getResult());
323 
324  auto memrefIndices = getMemrefIndices(
325  tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
326  tileSliceIndex, numTileSlices, loc, rewriter);
327 
328  // Splat pad into 1-D vector matching type of tile slice.
329  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
330  auto pad1DOp =
331  vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
332 
333  auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
334  tileLoadOp.getBase(),
335  memrefIndices, maskOp1D,
336  /*passthrough=*/pad1DOp);
337 
338  // Create 'arm_sme.insert_tile_slice' to insert slice into tile.
339  auto insertSlice = arm_sme::InsertTileSliceOp::create(
340  rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
341  tileSliceIndex, tileLoadOp.getLayout());
342  scf::YieldOp::create(rewriter, loc, insertSlice.getResult());
343 
344  rewriter.setInsertionPointAfter(forOp);
345 
346  // Replace 'arm_sme.tile_load' with the result.
347  rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
348 
349  return success();
350  }
351 };
352 
353 /// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
354 /// slice using `arm_sme.store_tile_slice`.
355 ///
356 /// BEFORE:
357 /// ```mlir
358 /// arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
359 /// : memref<?x?xi32>, vector<[4]x[4]xi32
360 /// ```
361 ///
362 /// AFTER:
363 /// ```mlir
364 /// %vscale = vector.vscale
365 /// %c0 = arith.constant 0 : index
366 /// %c1 = arith.constant 1 : index
367 /// %min_svl_s = arith.constant 4 : index
368 /// %svl_s = arith.muli %min_svl_s, %vscale : index
369 /// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
370 /// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
371 /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
372 /// }
373 /// ```
374 struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
376 
377  LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
378  PatternRewriter &rewriter) const override {
379  if (Value mask = tileStoreOp.getMask()) {
380  if (!mask.getDefiningOp<vector::CreateMaskOp>())
381  return rewriter.notifyMatchFailure(
382  tileStoreOp.getLoc(),
383  "unsupported mask op, only 'vector.create_mask' is "
384  "currently supported");
385  }
386 
387  // Create a loop that stores each active ZA tile slice from memory.
388  return createLoadStoreForOverTileSlices(
389  rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
390  tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
391  tileStoreOp.getMask(),
392  [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
393  rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
394  tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
395  predicate, tileStoreOp.getBase(), memrefIndices,
396  tileStoreOp.getLayout());
397  });
398  }
399 };
400 
401 } // namespace
402 
404  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
405  TileStoreOpConversion>(patterns.getContext());
406 }
407 
408 namespace {
409 
410 struct ConvertArmSMEToSCFPass
411  : public impl::ConvertArmSMEToSCFPassBase<ConvertArmSMEToSCFPass> {
412  void runOnOperation() override {
414  ConversionTarget target(getContext());
416  target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
417  arith::ArithDialect, scf::SCFDialect>();
418  target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
419  if (failed(applyPartialConversion(getOperation(), target,
420  std::move(patterns))))
421  signalPassFailure();
422  }
423 };
424 
425 } // namespace
static MLIRContext * getContext(OpFoldResult val)
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:323
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:286
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:311
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:17
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the ArmSME dialect to SCF.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314