MLIR  21.0.0git
ArmSMEToSCF.cpp
Go to the documentation of this file.
1 //===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of ArmSME operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
13 
18 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTARMSMETOSCFPASS
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 /// Returns adjusted (1-D or 2-D) `indices` for a tile slice as follows:
30 /// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
31 /// rank 2: (indices[0] + tileSliceIndex, indices[1])
32 SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
33  Value tileSliceIndex,
34  Value tileSliceNumElts, Location loc,
35  PatternRewriter &rewriter) {
36  assert(rank == 2 && "memref has unexpected rank!");
37  SmallVector<Value, 2> outIndices;
38 
39  auto tileSliceOffset = tileSliceIndex;
40 
41  auto baseIndexPlusTileSliceOffset =
42  rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
43  outIndices.push_back(baseIndexPlusTileSliceOffset);
44  outIndices.push_back(indices[1]);
45 
46  return outIndices;
47 }
48 
49 /// Creates an scf.for for the load/store of an ArmSME tile.
50 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
51  PatternRewriter &rewriter, Location loc, VectorType tileType,
52  ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
53  function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
54  /*currentTile=*/Value)>
55  makeLoopBody) {
56  PatternRewriter::InsertionGuard guard(rewriter);
57 
58  // TODO: This case should be captured and rejected by a verifier.
59  if (memrefIndices.size() != 2)
60  return rewriter.notifyMatchFailure(loc, "invalid number of indices");
61 
62  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
63  loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
64  auto vscale =
65  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
66  auto predicateType =
67  VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
68 
69  // This describes both the number of ZA tile slices and the number of
70  // elements in a vector of SVL bits for a given element type (SVL_B,
71  // SVL_H, ..., SVL_Q).
72  auto numTileSlices =
73  rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
74 
75  Value predicate;
76  Value upperBound;
77  if (mask) {
78  auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
79  auto maskDim0 = createMaskOp.getOperands()[0];
80  auto maskDim1 = createMaskOp.getOperands()[1];
81 
82  // The upper bound of the loop must be clamped at `numTileSlices` as
83  // `vector.create_mask` allows operands to be greater than the size of a
84  // dimension.
85  auto numRowI64 = rewriter.create<arith::IndexCastOp>(
86  loc, rewriter.getI64Type(), maskDim0);
87  auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
88  loc, rewriter.getI64Type(), numTileSlices);
89  auto upperBoundI64 =
90  rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
91  upperBound = rewriter.create<arith::IndexCastOp>(
92  loc, rewriter.getIndexType(), upperBoundI64);
93 
94  predicate =
95  rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
96  } else {
97  upperBound = numTileSlices;
98  // No mask. Create an 'all true' predicate for the tile slice.
99  predicate = rewriter.create<arith::ConstantOp>(
100  loc, DenseElementsAttr::get(predicateType, true));
101  }
102 
103  bool hasCarriedArgs = bool(initTile);
104  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
105  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
106  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
107  hasCarriedArgs ? ValueRange{initTile}
108  : ValueRange{});
109 
110  rewriter.setInsertionPointToStart(forOp.getBody());
111  Value tileSliceIndex = forOp.getInductionVar();
112 
113  auto adjustedIndices = getMemrefIndices(
114  memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
115  auto nextTile = makeLoopBody(
116  tileSliceIndex, adjustedIndices, predicate,
117  /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(0) : Value{});
118 
119  assert(bool(nextTile) == hasCarriedArgs);
120  if (nextTile)
121  rewriter.create<scf::YieldOp>(loc, nextTile);
122 
123  return forOp;
124 }
125 
126 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
127  PatternRewriter &rewriter, Location loc, VectorType tileType,
128  ValueRange memrefIndices, int memrefRank, Value mask,
129  function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
130  makeLoopBody) {
131  return createLoadStoreForOverTileSlices(
132  rewriter, loc, tileType, memrefIndices, memrefRank, mask, Value{},
133  [&](Value index, ValueRange adjustedIndices, Value predicate,
134  Value) -> Value {
135  makeLoopBody(index, adjustedIndices, predicate);
136  return {};
137  });
138 }
139 
140 /// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
141 ///
142 /// With a mask:
143 ///
144 /// BEFORE:
145 /// ```mlir
146 /// %pad = arith.constant 0 : i32
147 /// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
148 /// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
149 /// memref<?x?xi32>, vector<[4]x[4]xi32>
150 /// ```
151 ///
152 /// AFTER:
153 /// ```mlir
154 /// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
155 /// %mask_cols = vector.create_mask %num_cols : vector<[4]xi1>
156 /// %loop_rows = arith.minsi %num_rows, %svl_s : index
157 /// %tile = scf.for %tile_slice_idx = %c0 to %loop_rows step %c1
158 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
159 /// %tile_update = arm_sme.load_tile_slice
160 /// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
161 /// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
162 /// scf.yield %tile_update : vector<[4]x[4]xi32>
163 /// }
164 /// ```
165 ///
166 /// Without a mask the lowering is pretty much identical. The only difference is
167 /// %mask_cols becomes an all-true mask, and %loop_rows becomes %svl_s.
168 ///
169 /// NOTE: Only mask of 'vector.create_mask' op is currently supported.
170 struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
172 
173  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
174  PatternRewriter &rewriter) const override {
175  auto loc = tileLoadOp.getLoc();
176  auto tileType = tileLoadOp.getVectorType();
177  auto mask = tileLoadOp.getMask();
178 
179  Value initTile;
180  if (mask) {
181  if (!mask.getDefiningOp<vector::CreateMaskOp>())
182  return rewriter.notifyMatchFailure(
183  loc, "unsupported mask op, only 'vector.create_mask' is "
184  "currently supported");
185  auto padOp = tileLoadOp.getPadding();
186  assert(padOp && "expected padding when masking!");
187 
188  auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
189  if (!constPadOp || constPadOp.getValue() !=
190  rewriter.getZeroAttr(tileType.getElementType()))
191  return rewriter.notifyMatchFailure(
192  tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
193 
194  // Initialize tile with zero to satisfy padding. Inactive cols will be
195  // zeroed anyway since the loads use zeroing predication. For inactive
196  // rows however, no load will occur so these need to be zeroed.
197  initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
198  } else {
199  initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
200  }
201 
202  // Create a loop to load the active tile slices from memory.
203  auto forOp = createLoadStoreForOverTileSlices(
204  rewriter, loc, tileType, tileLoadOp.getIndices(),
205  tileLoadOp.getMemRefType().getRank(), mask, initTile,
206  [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
207  Value currentTile) -> Value {
208  // Create 'arm_sme.load_tile_slice' to load tile slice from memory
209  // into tile.
210  return rewriter.create<arm_sme::LoadTileSliceOp>(
211  loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
212  memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
213  });
214 
215  if (failed(forOp))
216  return forOp;
217 
218  // Replace 'arm_sme.tile_load' with the result.
219  rewriter.replaceOp(tileLoadOp, forOp->getResult(0));
220 
221  return success();
222  }
223 };
224 
225 /// Lower `arm_sme.tile_load` with mask and non-zero pad.
226 ///
227 /// BEFORE:
228 /// ```mlir
229 /// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
230 /// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
231 /// memref<?x?xi32>, vector<[4]x[4]xi32>
232 /// ```
233 ///
234 /// AFTER:
235 /// ```mlir
236 /// ...
237 /// %pad_1d = vector.splat %pad : vector<[4]xi32>
238 /// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
239 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
240 /// ...
241 /// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
242 /// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
243 /// : memref<?x?xi32>, vector<[4]xi1>,
244 /// vector<[4]xi32> into vector<[4]xi32>
245 /// // Insert slice into tile
246 /// %tile_update = arm_sme.insert_tile_slice
247 /// %slice, %iter_tile[%tile_slice_idx] :
248 /// vector<[4]xi32> into vector<[4]x[4]xi32>
249 /// scf.yield %tile_update : vector<[4]x[4]xi32>
250 /// }
251 /// ```
252 struct TileLoadOpWithMaskAndPadNonZeroConversion
253  : public OpRewritePattern<arm_sme::TileLoadOp> {
255 
256  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
257  PatternRewriter &rewriter) const override {
258  OpBuilder::InsertionGuard g(rewriter);
259  auto loc = tileLoadOp.getLoc();
260  auto tileType = tileLoadOp.getVectorType();
261  auto tileElementType = tileType.getElementType();
262 
263  auto maskOp = tileLoadOp.getMask();
264  if (!maskOp)
265  return rewriter.notifyMatchFailure(
266  tileLoadOp, "op has no mask, needs unmasked pattern");
267 
268  auto padOp = tileLoadOp.getPadding();
269  assert(padOp && "expected padding when masking!");
270 
271  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
272  if (!createMaskOp)
273  return rewriter.notifyMatchFailure(
274  tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
275  "currently supported");
276 
277  auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
278  if (constPadOp &&
279  constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
280  return rewriter.notifyMatchFailure(
281  tileLoadOp, "op has constant zero pad, needs zero pad pattern");
282 
283  auto numRows = createMaskOp.getOperands()[0];
284  auto numCols = createMaskOp.getOperands()[1];
285 
286  auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
287  loc, rewriter.getI32Type(), numCols);
288 
289  auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
290 
291  // Create a loop that loads each ZA tile slice from memory.
292  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
293  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
294  loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
295  auto vscale =
296  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
297  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
298  auto numTileSlices =
299  rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
300  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
301  step, ValueRange{initTile});
302 
303  rewriter.setInsertionPointToStart(forOp.getBody());
304 
305  auto tileSliceIndex = forOp.getInductionVar();
306  auto currentTile = forOp.getRegionIterArg(0);
307 
308  // Combine masks.
309  auto rowIsActive = rewriter.create<arith::CmpIOp>(
310  loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
311  auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
312  loc, rewriter.getI32Type(), rowIsActive);
313  auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
314  auto maskIndex =
315  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
316  auto predicateType =
317  VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
318  auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
319  loc, predicateType, maskIndex.getResult());
320 
321  auto memrefIndices = getMemrefIndices(
322  tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
323  tileSliceIndex, numTileSlices, loc, rewriter);
324 
325  // Splat pad into 1-D vector matching type of tile slice.
326  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
327  auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
328 
329  auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
330  loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
331  /*passthru=*/pad1DOp);
332 
333  // Create 'arm_sme.insert_tile_slice' to insert slice into tile.
334  auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>(
335  loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
336  tileLoadOp.getLayout());
337  rewriter.create<scf::YieldOp>(loc, insertSlice.getResult());
338 
339  rewriter.setInsertionPointAfter(forOp);
340 
341  // Replace 'arm_sme.tile_load' with the result.
342  rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
343 
344  return success();
345  }
346 };
347 
348 /// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
349 /// slice using `arm_sme.store_tile_slice`.
350 ///
351 /// BEFORE:
352 /// ```mlir
353 /// arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
354 /// : memref<?x?xi32>, vector<[4]x[4]xi32
355 /// ```
356 ///
357 /// AFTER:
358 /// ```mlir
359 /// %vscale = vector.vscale
360 /// %c0 = arith.constant 0 : index
361 /// %c1 = arith.constant 1 : index
362 /// %min_svl_s = arith.constant 4 : index
363 /// %svl_s = arith.muli %min_svl_s, %vscale : index
364 /// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
365 /// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
366 /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
367 /// }
368 /// ```
369 struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
371 
372  LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
373  PatternRewriter &rewriter) const override {
374  if (Value mask = tileStoreOp.getMask()) {
375  if (!mask.getDefiningOp<vector::CreateMaskOp>())
376  return rewriter.notifyMatchFailure(
377  tileStoreOp.getLoc(),
378  "unsupported mask op, only 'vector.create_mask' is "
379  "currently supported");
380  }
381 
382  // Create a loop that stores each active ZA tile slice from memory.
383  return createLoadStoreForOverTileSlices(
384  rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
385  tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
386  tileStoreOp.getMask(),
387  [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
388  rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
389  tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
390  predicate, tileStoreOp.getBase(), memrefIndices,
391  tileStoreOp.getLayout());
392  });
393  }
394 };
395 
396 } // namespace
397 
399  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
400  TileStoreOpConversion>(patterns.getContext());
401 }
402 
403 namespace {
404 
405 struct ConvertArmSMEToSCFPass
406  : public impl::ConvertArmSMEToSCFPassBase<ConvertArmSMEToSCFPass> {
407  void runOnOperation() override {
409  ConversionTarget target(getContext());
411  target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
412  arith::ArithDialect, scf::SCFDialect>();
413  target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
414  if (failed(applyPartialConversion(getOperation(), target,
415  std::move(patterns))))
416  signalPassFailure();
417  }
418 };
419 
420 } // namespace
static MLIRContext * getContext(OpFoldResult val)
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:428
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:409
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:270
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:295
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:18
Include the generated interface declarations.
void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the ArmSME dialect to SCF.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314