MLIR  19.0.0git
ArmSMEToSCF.cpp
Go to the documentation of this file.
1 //===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of ArmSME operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
13 
18 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTARMSMETOSCF
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 /// Returns adjusted (1-D or 2-D) `indices` for a tile slice as follows:
30 /// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
31 /// rank 2: (indices[0] + tileSliceIndex, indices[1])
32 SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
33  Value tileSliceIndex,
34  Value tileSliceNumElts, Location loc,
35  PatternRewriter &rewriter) {
36  assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
37  SmallVector<Value, 2> outIndices;
38 
39  auto tileSliceOffset = tileSliceIndex;
40  if (rank == 1)
41  tileSliceOffset =
42  rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
43 
44  auto baseIndexPlusTileSliceOffset =
45  rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
46  outIndices.push_back(baseIndexPlusTileSliceOffset);
47 
48  if (rank == 2)
49  outIndices.push_back(indices[1]);
50 
51  return outIndices;
52 }
53 
54 /// Creates an scf.for for the load/store of an ArmSME tile.
55 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
56  PatternRewriter &rewriter, Location loc, VectorType tileType,
57  ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
58  function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
59  /*currentTile=*/Value)>
60  makeLoopBody) {
61  PatternRewriter::InsertionGuard guard(rewriter);
62 
63  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
64  loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
65  auto vscale =
66  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
67  auto predicateType =
68  VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
69 
70  // This describes both the number of ZA tile slices and the number of
71  // elements in a vector of SVL bits for a given element type (SVL_B,
72  // SVL_H, ..., SVL_Q).
73  auto numTileSlices =
74  rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
75 
76  Value predicate;
77  Value upperBound;
78  if (mask) {
79  auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
80  if (!createMaskOp)
81  return rewriter.notifyMatchFailure(
82  loc, "unsupported mask op, only 'vector.create_mask' is "
83  "currently supported");
84 
85  auto maskDim0 = createMaskOp.getOperands()[0];
86  auto maskDim1 = createMaskOp.getOperands()[1];
87 
88  // The upper bound of the loop must be clamped at `numTileSlices` as
89  // `vector.create_mask` allows operands to be greater than the size of a
90  // dimension.
91  auto numRowI64 = rewriter.create<arith::IndexCastOp>(
92  loc, rewriter.getI64Type(), maskDim0);
93  auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
94  loc, rewriter.getI64Type(), numTileSlices);
95  auto upperBoundI64 =
96  rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
97  upperBound = rewriter.create<arith::IndexCastOp>(
98  loc, rewriter.getIndexType(), upperBoundI64);
99 
100  predicate =
101  rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
102  } else {
103  upperBound = numTileSlices;
104  // No mask. Create an 'all true' predicate for the tile slice.
105  predicate = rewriter.create<arith::ConstantOp>(
106  loc, DenseElementsAttr::get(predicateType, true));
107  }
108 
109  bool hasCarriedArgs = bool(initTile);
110  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
111  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
112  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
113  hasCarriedArgs ? ValueRange{initTile}
114  : ValueRange{});
115 
116  rewriter.setInsertionPointToStart(forOp.getBody());
117  Value tileSliceIndex = forOp.getInductionVar();
118 
119  auto adjustedIndices = getMemrefIndices(
120  memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
121  auto nextTile = makeLoopBody(
122  tileSliceIndex, adjustedIndices, predicate,
123  /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(0) : Value{});
124 
125  assert(bool(nextTile) == hasCarriedArgs);
126  if (nextTile)
127  rewriter.create<scf::YieldOp>(loc, nextTile);
128 
129  return forOp;
130 }
131 
132 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
133  PatternRewriter &rewriter, Location loc, VectorType tileType,
134  ValueRange memrefIndices, int memrefRank, Value mask,
135  function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
136  makeLoopBody) {
137  return createLoadStoreForOverTileSlices(
138  rewriter, loc, tileType, memrefIndices, memrefRank, mask, Value{},
139  [&](Value index, ValueRange adjustedIndices, Value predicate,
140  Value) -> Value {
141  makeLoopBody(index, adjustedIndices, predicate);
142  return {};
143  });
144 }
145 
146 /// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
147 ///
148 /// With a mask:
149 ///
150 /// BEFORE:
151 /// ```mlir
152 /// %pad = arith.constant 0 : i32
153 /// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
154 /// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
155 /// memref<?x?xi32>, vector<[4]x[4]xi32>
156 /// ```
157 ///
158 /// AFTER:
159 /// ```mlir
160 /// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
161 /// %mask_cols = vector.create_mask %num_cols : vector<[4]xi1>
162 /// %loop_rows = arith.minsi %num_rows, %svl_s : index
163 /// %tile = scf.for %tile_slice_idx = %c0 to %loop_rows step %c1
164 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
165 /// %tile_update = arm_sme.load_tile_slice
166 /// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
167 /// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
168 /// scf.yield %tile_update : vector<[4]x[4]xi32>
169 /// }
170 /// ```
171 ///
172 /// Without a mask the lowering is pretty much identical. The only difference is
173 /// %mask_cols becomes an all-true mask, and %loop_rows becomes %svl_s.
174 ///
175 /// NOTE: Only mask of 'vector.create_mask' op is currently supported.
176 struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
178 
179  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
180  PatternRewriter &rewriter) const override {
181  auto loc = tileLoadOp.getLoc();
182  auto tileType = tileLoadOp.getVectorType();
183  auto mask = tileLoadOp.getMask();
184 
185  Value initTile;
186  if (mask) {
187  auto padOp = tileLoadOp.getPadding();
188  assert(padOp && "expected padding when masking!");
189 
190  auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
191  if (!constPadOp || constPadOp.getValue() !=
192  rewriter.getZeroAttr(tileType.getElementType()))
193  return rewriter.notifyMatchFailure(
194  tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
195 
196  // Initialize tile with zero to satisfy padding. Inactive cols will be
197  // zeroed anyway since the loads use zeroing predication. For inactive
198  // rows however, no load will occur so these need to be zeroed.
199  initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
200  rewriter, loc, tileType);
201  } else {
202  // Allocate a new SME tile.
203  initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
204  rewriter, loc, tileType);
205  }
206 
207  // Create a loop to load the active tile slices from memory.
208  auto forOp = createLoadStoreForOverTileSlices(
209  rewriter, loc, tileType, tileLoadOp.getIndices(),
210  tileLoadOp.getMemRefType().getRank(), mask, initTile,
211  [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
212  Value currentTile) -> Value {
213  // Create 'arm_sme.load_tile_slice' to load tile slice from memory
214  // into tile.
215  return tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
216  rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
217  currentTile, memrefIndices, tileSliceIndex,
218  tileLoadOp.getLayout());
219  });
220 
221  if (failed(forOp))
222  return forOp;
223 
224  // Replace 'arm_sme.tile_load' with the result.
225  rewriter.replaceOp(tileLoadOp, forOp->getResult(0));
226 
227  return success();
228  }
229 };
230 
231 /// Lower `arm_sme.tile_load` with mask and non-zero pad.
232 ///
233 /// BEFORE:
234 /// ```mlir
235 /// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
236 /// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
237 /// memref<?x?xi32>, vector<[4]x[4]xi32>
238 /// ```
239 ///
240 /// AFTER:
241 /// ```mlir
242 /// ...
243 /// %pad_1d = vector.splat %pad : vector<[4]xi32>
244 /// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
245 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
246 /// ...
247 /// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
248 /// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
249 /// : memref<?x?xi32>, vector<[4]xi1>,
250 /// vector<[4]xi32> into vector<[4]xi32>
251 /// // Insert slice into tile
252 /// %tile_update = arm_sme.move_vector_to_tile_slice
253 /// %slice, %iter_tile, %tile_slice_idx :
254 /// vector<[4]xi32> into vector<[4]x[4]xi32>
255 /// scf.yield %tile_update : vector<[4]x[4]xi32>
256 /// }
257 /// ```
258 struct TileLoadOpWithMaskAndPadNonZeroConversion
259  : public OpRewritePattern<arm_sme::TileLoadOp> {
261 
262  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
263  PatternRewriter &rewriter) const override {
264  OpBuilder::InsertionGuard g(rewriter);
265  auto loc = tileLoadOp.getLoc();
266  auto tileType = tileLoadOp.getVectorType();
267  auto tileElementType = tileType.getElementType();
268 
269  auto maskOp = tileLoadOp.getMask();
270  if (!maskOp)
271  return rewriter.notifyMatchFailure(
272  tileLoadOp, "op has no mask, needs unmasked pattern");
273 
274  auto padOp = tileLoadOp.getPadding();
275  assert(padOp && "expected padding when masking!");
276 
277  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
278  if (!createMaskOp)
279  return rewriter.notifyMatchFailure(
280  tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
281  "currently supported");
282 
283  auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
284  if (constPadOp &&
285  constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
286  return rewriter.notifyMatchFailure(
287  tileLoadOp, "op has constant zero pad, needs zero pad pattern");
288 
289  auto numRows = createMaskOp.getOperands()[0];
290  auto numCols = createMaskOp.getOperands()[1];
291 
292  auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
293  loc, rewriter.getI32Type(), numCols);
294 
295  // Allocate a new SME tile.
296  auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
297  rewriter, loc, tileType);
298 
299  // Create a loop that loads each ZA tile slice from memory.
300  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
301  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
302  loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
303  auto vscale =
304  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
305  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
306  auto numTileSlices =
307  rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
308  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
309  step, ValueRange{initTile});
310 
311  rewriter.setInsertionPointToStart(forOp.getBody());
312 
313  auto tileSliceIndex = forOp.getInductionVar();
314  auto currentTile = forOp.getRegionIterArg(0);
315 
316  // Combine masks.
317  auto rowIsActive = rewriter.create<arith::CmpIOp>(
318  loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
319  auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
320  loc, rewriter.getI32Type(), rowIsActive);
321  auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
322  auto maskIndex =
323  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
324  auto predicateType =
325  VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
326  auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
327  loc, predicateType, maskIndex.getResult());
328 
329  auto memrefIndices = getMemrefIndices(
330  tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
331  tileSliceIndex, numTileSlices, loc, rewriter);
332 
333  // Splat pad into 1-D vector matching type of tile slice.
334  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
335  auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
336 
337  auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
338  loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
339  /*passthru=*/pad1DOp);
340 
341  // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
342  auto moveSlice =
343  tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
344  rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
345  tileSliceIndex, tileLoadOp.getLayout());
346  rewriter.create<scf::YieldOp>(loc, moveSlice.getResult());
347 
348  rewriter.setInsertionPointAfter(forOp);
349 
350  // Replace 'arm_sme.tile_load' with the result.
351  rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
352 
353  return success();
354  }
355 };
356 
357 /// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
358 /// slice using `arm_sme.store_tile_slice`.
359 ///
360 /// BEFORE:
361 /// ```mlir
362 /// arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
363 /// : memref<?x?xi32>, vector<[4]x[4]xi32
364 /// ```
365 ///
366 /// AFTER:
367 /// ```mlir
368 /// %vscale = vector.vscale
369 /// %c0 = arith.constant 0 : index
370 /// %c1 = arith.constant 1 : index
371 /// %min_svl_s = arith.constant 4 : index
372 /// %svl_s = arith.muli %min_svl_s, %vscale : index
373 /// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
374 /// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
375 /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
376 /// }
377 /// ```
378 struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
380 
381  LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
382  PatternRewriter &rewriter) const override {
383  // Create a loop that stores each active ZA tile slice from memory.
384  return createLoadStoreForOverTileSlices(
385  rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
386  tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
387  tileStoreOp.getMask(),
388  [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
389  tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
390  rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
391  predicate, tileStoreOp.getBase(), memrefIndices,
392  tileStoreOp.getLayout());
393  });
394  }
395 };
396 
397 } // namespace
398 
400  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
401  TileStoreOpConversion>(patterns.getContext());
402 }
403 
404 namespace {
405 
406 struct ConvertArmSMEToSCFPass
407  : public impl::ConvertArmSMEToSCFBase<ConvertArmSMEToSCFPass> {
408  void runOnOperation() override {
409  RewritePatternSet patterns(&getContext());
410  ConversionTarget target(getContext());
412  target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
413  arith::ArithDialect, scf::SCFDialect>();
414  target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
415  if (failed(applyPartialConversion(getOperation(), target,
416  std::move(patterns))))
417  signalPassFailure();
418  }
419 };
420 
421 } // namespace
422 
423 std::unique_ptr<Pass> mlir::createConvertArmSMEToSCFPass() {
424  return std::make_unique<ConvertArmSMEToSCFPass>();
425 }
static MLIRContext * getContext(OpFoldResult val)
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getI32Type()
Definition: Builders.cpp:83
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:330
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:18
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToSCFPass()
Create a pass to convert a subset of ArmSME ops to SCF.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the ArmSME dialect to SCF.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358