MLIR  21.0.0git
ArmSMEToSCF.cpp
Go to the documentation of this file.
1 //===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of ArmSME operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
13 
18 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTARMSMETOSCFPASS
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 /// Returns adjusted (1-D or 2-D) `indices` for a tile slice as follows:
30 /// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
31 /// rank 2: (indices[0] + tileSliceIndex, indices[1])
32 SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
33  Value tileSliceIndex,
34  Value tileSliceNumElts, Location loc,
35  PatternRewriter &rewriter) {
36  assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
37  SmallVector<Value, 2> outIndices;
38 
39  auto tileSliceOffset = tileSliceIndex;
40  if (rank == 1)
41  tileSliceOffset =
42  rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
43 
44  auto baseIndexPlusTileSliceOffset =
45  rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
46  outIndices.push_back(baseIndexPlusTileSliceOffset);
47 
48  if (rank == 2)
49  outIndices.push_back(indices[1]);
50 
51  return outIndices;
52 }
53 
54 /// Creates an scf.for for the load/store of an ArmSME tile.
55 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
56  PatternRewriter &rewriter, Location loc, VectorType tileType,
57  ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
58  function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
59  /*currentTile=*/Value)>
60  makeLoopBody) {
61  PatternRewriter::InsertionGuard guard(rewriter);
62 
63  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
64  loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
65  auto vscale =
66  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
67  auto predicateType =
68  VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
69 
70  // This describes both the number of ZA tile slices and the number of
71  // elements in a vector of SVL bits for a given element type (SVL_B,
72  // SVL_H, ..., SVL_Q).
73  auto numTileSlices =
74  rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
75 
76  Value predicate;
77  Value upperBound;
78  if (mask) {
79  auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
80  auto maskDim0 = createMaskOp.getOperands()[0];
81  auto maskDim1 = createMaskOp.getOperands()[1];
82 
83  // The upper bound of the loop must be clamped at `numTileSlices` as
84  // `vector.create_mask` allows operands to be greater than the size of a
85  // dimension.
86  auto numRowI64 = rewriter.create<arith::IndexCastOp>(
87  loc, rewriter.getI64Type(), maskDim0);
88  auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
89  loc, rewriter.getI64Type(), numTileSlices);
90  auto upperBoundI64 =
91  rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
92  upperBound = rewriter.create<arith::IndexCastOp>(
93  loc, rewriter.getIndexType(), upperBoundI64);
94 
95  predicate =
96  rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
97  } else {
98  upperBound = numTileSlices;
99  // No mask. Create an 'all true' predicate for the tile slice.
100  predicate = rewriter.create<arith::ConstantOp>(
101  loc, DenseElementsAttr::get(predicateType, true));
102  }
103 
104  bool hasCarriedArgs = bool(initTile);
105  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
106  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
107  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
108  hasCarriedArgs ? ValueRange{initTile}
109  : ValueRange{});
110 
111  rewriter.setInsertionPointToStart(forOp.getBody());
112  Value tileSliceIndex = forOp.getInductionVar();
113 
114  auto adjustedIndices = getMemrefIndices(
115  memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
116  auto nextTile = makeLoopBody(
117  tileSliceIndex, adjustedIndices, predicate,
118  /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(0) : Value{});
119 
120  assert(bool(nextTile) == hasCarriedArgs);
121  if (nextTile)
122  rewriter.create<scf::YieldOp>(loc, nextTile);
123 
124  return forOp;
125 }
126 
127 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
128  PatternRewriter &rewriter, Location loc, VectorType tileType,
129  ValueRange memrefIndices, int memrefRank, Value mask,
130  function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
131  makeLoopBody) {
132  return createLoadStoreForOverTileSlices(
133  rewriter, loc, tileType, memrefIndices, memrefRank, mask, Value{},
134  [&](Value index, ValueRange adjustedIndices, Value predicate,
135  Value) -> Value {
136  makeLoopBody(index, adjustedIndices, predicate);
137  return {};
138  });
139 }
140 
141 /// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
142 ///
143 /// With a mask:
144 ///
145 /// BEFORE:
146 /// ```mlir
147 /// %pad = arith.constant 0 : i32
148 /// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
149 /// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
150 /// memref<?x?xi32>, vector<[4]x[4]xi32>
151 /// ```
152 ///
153 /// AFTER:
154 /// ```mlir
155 /// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
156 /// %mask_cols = vector.create_mask %num_cols : vector<[4]xi1>
157 /// %loop_rows = arith.minsi %num_rows, %svl_s : index
158 /// %tile = scf.for %tile_slice_idx = %c0 to %loop_rows step %c1
159 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
160 /// %tile_update = arm_sme.load_tile_slice
161 /// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
162 /// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
163 /// scf.yield %tile_update : vector<[4]x[4]xi32>
164 /// }
165 /// ```
166 ///
167 /// Without a mask the lowering is pretty much identical. The only difference is
168 /// %mask_cols becomes an all-true mask, and %loop_rows becomes %svl_s.
169 ///
170 /// NOTE: Only mask of 'vector.create_mask' op is currently supported.
171 struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
173 
174  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
175  PatternRewriter &rewriter) const override {
176  auto loc = tileLoadOp.getLoc();
177  auto tileType = tileLoadOp.getVectorType();
178  auto mask = tileLoadOp.getMask();
179 
180  Value initTile;
181  if (mask) {
182  if (!mask.getDefiningOp<vector::CreateMaskOp>())
183  return rewriter.notifyMatchFailure(
184  loc, "unsupported mask op, only 'vector.create_mask' is "
185  "currently supported");
186  auto padOp = tileLoadOp.getPadding();
187  assert(padOp && "expected padding when masking!");
188 
189  auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
190  if (!constPadOp || constPadOp.getValue() !=
191  rewriter.getZeroAttr(tileType.getElementType()))
192  return rewriter.notifyMatchFailure(
193  tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
194 
195  // Initialize tile with zero to satisfy padding. Inactive cols will be
196  // zeroed anyway since the loads use zeroing predication. For inactive
197  // rows however, no load will occur so these need to be zeroed.
198  initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
199  } else {
200  initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
201  }
202 
203  // Create a loop to load the active tile slices from memory.
204  auto forOp = createLoadStoreForOverTileSlices(
205  rewriter, loc, tileType, tileLoadOp.getIndices(),
206  tileLoadOp.getMemRefType().getRank(), mask, initTile,
207  [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
208  Value currentTile) -> Value {
209  // Create 'arm_sme.load_tile_slice' to load tile slice from memory
210  // into tile.
211  return rewriter.create<arm_sme::LoadTileSliceOp>(
212  loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
213  memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
214  });
215 
216  if (failed(forOp))
217  return forOp;
218 
219  // Replace 'arm_sme.tile_load' with the result.
220  rewriter.replaceOp(tileLoadOp, forOp->getResult(0));
221 
222  return success();
223  }
224 };
225 
226 /// Lower `arm_sme.tile_load` with mask and non-zero pad.
227 ///
228 /// BEFORE:
229 /// ```mlir
230 /// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
231 /// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
232 /// memref<?x?xi32>, vector<[4]x[4]xi32>
233 /// ```
234 ///
235 /// AFTER:
236 /// ```mlir
237 /// ...
238 /// %pad_1d = vector.splat %pad : vector<[4]xi32>
239 /// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
240 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
241 /// ...
242 /// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
243 /// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
244 /// : memref<?x?xi32>, vector<[4]xi1>,
245 /// vector<[4]xi32> into vector<[4]xi32>
246 /// // Insert slice into tile
247 /// %tile_update = arm_sme.insert_tile_slice
248 /// %slice, %iter_tile[%tile_slice_idx] :
249 /// vector<[4]xi32> into vector<[4]x[4]xi32>
250 /// scf.yield %tile_update : vector<[4]x[4]xi32>
251 /// }
252 /// ```
253 struct TileLoadOpWithMaskAndPadNonZeroConversion
254  : public OpRewritePattern<arm_sme::TileLoadOp> {
256 
257  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
258  PatternRewriter &rewriter) const override {
259  OpBuilder::InsertionGuard g(rewriter);
260  auto loc = tileLoadOp.getLoc();
261  auto tileType = tileLoadOp.getVectorType();
262  auto tileElementType = tileType.getElementType();
263 
264  auto maskOp = tileLoadOp.getMask();
265  if (!maskOp)
266  return rewriter.notifyMatchFailure(
267  tileLoadOp, "op has no mask, needs unmasked pattern");
268 
269  auto padOp = tileLoadOp.getPadding();
270  assert(padOp && "expected padding when masking!");
271 
272  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
273  if (!createMaskOp)
274  return rewriter.notifyMatchFailure(
275  tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
276  "currently supported");
277 
278  auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
279  if (constPadOp &&
280  constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
281  return rewriter.notifyMatchFailure(
282  tileLoadOp, "op has constant zero pad, needs zero pad pattern");
283 
284  auto numRows = createMaskOp.getOperands()[0];
285  auto numCols = createMaskOp.getOperands()[1];
286 
287  auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
288  loc, rewriter.getI32Type(), numCols);
289 
290  auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
291 
292  // Create a loop that loads each ZA tile slice from memory.
293  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
294  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
295  loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
296  auto vscale =
297  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
298  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
299  auto numTileSlices =
300  rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
301  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
302  step, ValueRange{initTile});
303 
304  rewriter.setInsertionPointToStart(forOp.getBody());
305 
306  auto tileSliceIndex = forOp.getInductionVar();
307  auto currentTile = forOp.getRegionIterArg(0);
308 
309  // Combine masks.
310  auto rowIsActive = rewriter.create<arith::CmpIOp>(
311  loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
312  auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
313  loc, rewriter.getI32Type(), rowIsActive);
314  auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
315  auto maskIndex =
316  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
317  auto predicateType =
318  VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
319  auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
320  loc, predicateType, maskIndex.getResult());
321 
322  auto memrefIndices = getMemrefIndices(
323  tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
324  tileSliceIndex, numTileSlices, loc, rewriter);
325 
326  // Splat pad into 1-D vector matching type of tile slice.
327  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
328  auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
329 
330  auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
331  loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
332  /*passthru=*/pad1DOp);
333 
334  // Create 'arm_sme.insert_tile_slice' to insert slice into tile.
335  auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>(
336  loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
337  tileLoadOp.getLayout());
338  rewriter.create<scf::YieldOp>(loc, insertSlice.getResult());
339 
340  rewriter.setInsertionPointAfter(forOp);
341 
342  // Replace 'arm_sme.tile_load' with the result.
343  rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
344 
345  return success();
346  }
347 };
348 
349 /// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
350 /// slice using `arm_sme.store_tile_slice`.
351 ///
352 /// BEFORE:
353 /// ```mlir
354 /// arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
355 /// : memref<?x?xi32>, vector<[4]x[4]xi32
356 /// ```
357 ///
358 /// AFTER:
359 /// ```mlir
360 /// %vscale = vector.vscale
361 /// %c0 = arith.constant 0 : index
362 /// %c1 = arith.constant 1 : index
363 /// %min_svl_s = arith.constant 4 : index
364 /// %svl_s = arith.muli %min_svl_s, %vscale : index
365 /// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
366 /// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
367 /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
368 /// }
369 /// ```
370 struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
372 
373  LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
374  PatternRewriter &rewriter) const override {
375  if (Value mask = tileStoreOp.getMask()) {
376  if (!mask.getDefiningOp<vector::CreateMaskOp>())
377  return rewriter.notifyMatchFailure(
378  tileStoreOp.getLoc(),
379  "unsupported mask op, only 'vector.create_mask' is "
380  "currently supported");
381  }
382 
383  // Create a loop that stores each active ZA tile slice from memory.
384  return createLoadStoreForOverTileSlices(
385  rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
386  tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
387  tileStoreOp.getMask(),
388  [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
389  rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
390  tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
391  predicate, tileStoreOp.getBase(), memrefIndices,
392  tileStoreOp.getLayout());
393  });
394  }
395 };
396 
397 } // namespace
398 
400  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
401  TileStoreOpConversion>(patterns.getContext());
402 }
403 
404 namespace {
405 
406 struct ConvertArmSMEToSCFPass
407  : public impl::ConvertArmSMEToSCFPassBase<ConvertArmSMEToSCFPass> {
408  void runOnOperation() override {
410  ConversionTarget target(getContext());
412  target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
413  arith::ArithDialect, scf::SCFDialect>();
414  target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
415  if (failed(applyPartialConversion(getOperation(), target,
416  std::move(patterns))))
417  signalPassFailure();
418  }
419 };
420 
421 } // namespace
static MLIRContext * getContext(OpFoldResult val)
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:270
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:295
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:18
Include the generated interface declarations.
void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the ArmSME dialect to SCF.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314