MLIR  20.0.0git
ArmSMEToSCF.cpp
Go to the documentation of this file.
1 //===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of ArmSME operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
13 
18 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTARMSMETOSCF
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 /// Returns adjusted (1-D or 2-D) `indices` for a tile slice as follows:
30 /// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
31 /// rank 2: (indices[0] + tileSliceIndex, indices[1])
32 SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
33  Value tileSliceIndex,
34  Value tileSliceNumElts, Location loc,
35  PatternRewriter &rewriter) {
36  assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
37  SmallVector<Value, 2> outIndices;
38 
39  auto tileSliceOffset = tileSliceIndex;
40  if (rank == 1)
41  tileSliceOffset =
42  rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
43 
44  auto baseIndexPlusTileSliceOffset =
45  rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
46  outIndices.push_back(baseIndexPlusTileSliceOffset);
47 
48  if (rank == 2)
49  outIndices.push_back(indices[1]);
50 
51  return outIndices;
52 }
53 
54 /// Creates an scf.for for the load/store of an ArmSME tile.
55 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
56  PatternRewriter &rewriter, Location loc, VectorType tileType,
57  ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
58  function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
59  /*currentTile=*/Value)>
60  makeLoopBody) {
61  PatternRewriter::InsertionGuard guard(rewriter);
62 
63  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
64  loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
65  auto vscale =
66  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
67  auto predicateType =
68  VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
69 
70  // This describes both the number of ZA tile slices and the number of
71  // elements in a vector of SVL bits for a given element type (SVL_B,
72  // SVL_H, ..., SVL_Q).
73  auto numTileSlices =
74  rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
75 
76  Value predicate;
77  Value upperBound;
78  if (mask) {
79  auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
80  if (!createMaskOp)
81  return rewriter.notifyMatchFailure(
82  loc, "unsupported mask op, only 'vector.create_mask' is "
83  "currently supported");
84 
85  auto maskDim0 = createMaskOp.getOperands()[0];
86  auto maskDim1 = createMaskOp.getOperands()[1];
87 
88  // The upper bound of the loop must be clamped at `numTileSlices` as
89  // `vector.create_mask` allows operands to be greater than the size of a
90  // dimension.
91  auto numRowI64 = rewriter.create<arith::IndexCastOp>(
92  loc, rewriter.getI64Type(), maskDim0);
93  auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
94  loc, rewriter.getI64Type(), numTileSlices);
95  auto upperBoundI64 =
96  rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
97  upperBound = rewriter.create<arith::IndexCastOp>(
98  loc, rewriter.getIndexType(), upperBoundI64);
99 
100  predicate =
101  rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
102  } else {
103  upperBound = numTileSlices;
104  // No mask. Create an 'all true' predicate for the tile slice.
105  predicate = rewriter.create<arith::ConstantOp>(
106  loc, DenseElementsAttr::get(predicateType, true));
107  }
108 
109  bool hasCarriedArgs = bool(initTile);
110  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
111  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
112  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
113  hasCarriedArgs ? ValueRange{initTile}
114  : ValueRange{});
115 
116  rewriter.setInsertionPointToStart(forOp.getBody());
117  Value tileSliceIndex = forOp.getInductionVar();
118 
119  auto adjustedIndices = getMemrefIndices(
120  memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
121  auto nextTile = makeLoopBody(
122  tileSliceIndex, adjustedIndices, predicate,
123  /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(0) : Value{});
124 
125  assert(bool(nextTile) == hasCarriedArgs);
126  if (nextTile)
127  rewriter.create<scf::YieldOp>(loc, nextTile);
128 
129  return forOp;
130 }
131 
132 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
133  PatternRewriter &rewriter, Location loc, VectorType tileType,
134  ValueRange memrefIndices, int memrefRank, Value mask,
135  function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
136  makeLoopBody) {
137  return createLoadStoreForOverTileSlices(
138  rewriter, loc, tileType, memrefIndices, memrefRank, mask, Value{},
139  [&](Value index, ValueRange adjustedIndices, Value predicate,
140  Value) -> Value {
141  makeLoopBody(index, adjustedIndices, predicate);
142  return {};
143  });
144 }
145 
146 /// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
147 ///
148 /// With a mask:
149 ///
150 /// BEFORE:
151 /// ```mlir
152 /// %pad = arith.constant 0 : i32
153 /// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
154 /// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
155 /// memref<?x?xi32>, vector<[4]x[4]xi32>
156 /// ```
157 ///
158 /// AFTER:
159 /// ```mlir
160 /// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
161 /// %mask_cols = vector.create_mask %num_cols : vector<[4]xi1>
162 /// %loop_rows = arith.minsi %num_rows, %svl_s : index
163 /// %tile = scf.for %tile_slice_idx = %c0 to %loop_rows step %c1
164 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
165 /// %tile_update = arm_sme.load_tile_slice
166 /// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
167 /// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
168 /// scf.yield %tile_update : vector<[4]x[4]xi32>
169 /// }
170 /// ```
171 ///
172 /// Without a mask the lowering is pretty much identical. The only difference is
173 /// %mask_cols becomes an all-true mask, and %loop_rows becomes %svl_s.
174 ///
175 /// NOTE: Only mask of 'vector.create_mask' op is currently supported.
176 struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
178 
179  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
180  PatternRewriter &rewriter) const override {
181  auto loc = tileLoadOp.getLoc();
182  auto tileType = tileLoadOp.getVectorType();
183  auto mask = tileLoadOp.getMask();
184 
185  Value initTile;
186  if (mask) {
187  auto padOp = tileLoadOp.getPadding();
188  assert(padOp && "expected padding when masking!");
189 
190  auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
191  if (!constPadOp || constPadOp.getValue() !=
192  rewriter.getZeroAttr(tileType.getElementType()))
193  return rewriter.notifyMatchFailure(
194  tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
195 
196  // Initialize tile with zero to satisfy padding. Inactive cols will be
197  // zeroed anyway since the loads use zeroing predication. For inactive
198  // rows however, no load will occur so these need to be zeroed.
199  initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
200  } else {
201  initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
202  }
203 
204  // Create a loop to load the active tile slices from memory.
205  auto forOp = createLoadStoreForOverTileSlices(
206  rewriter, loc, tileType, tileLoadOp.getIndices(),
207  tileLoadOp.getMemRefType().getRank(), mask, initTile,
208  [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
209  Value currentTile) -> Value {
210  // Create 'arm_sme.load_tile_slice' to load tile slice from memory
211  // into tile.
212  return rewriter.create<arm_sme::LoadTileSliceOp>(
213  loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
214  memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
215  });
216 
217  if (failed(forOp))
218  return forOp;
219 
220  // Replace 'arm_sme.tile_load' with the result.
221  rewriter.replaceOp(tileLoadOp, forOp->getResult(0));
222 
223  return success();
224  }
225 };
226 
227 /// Lower `arm_sme.tile_load` with mask and non-zero pad.
228 ///
229 /// BEFORE:
230 /// ```mlir
231 /// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
232 /// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
233 /// memref<?x?xi32>, vector<[4]x[4]xi32>
234 /// ```
235 ///
236 /// AFTER:
237 /// ```mlir
238 /// ...
239 /// %pad_1d = vector.splat %pad : vector<[4]xi32>
240 /// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
241 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
242 /// ...
243 /// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
244 /// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
245 /// : memref<?x?xi32>, vector<[4]xi1>,
246 /// vector<[4]xi32> into vector<[4]xi32>
247 /// // Insert slice into tile
248 /// %tile_update = arm_sme.insert_tile_slice
249 /// %slice, %iter_tile[%tile_slice_idx] :
250 /// vector<[4]xi32> into vector<[4]x[4]xi32>
251 /// scf.yield %tile_update : vector<[4]x[4]xi32>
252 /// }
253 /// ```
254 struct TileLoadOpWithMaskAndPadNonZeroConversion
255  : public OpRewritePattern<arm_sme::TileLoadOp> {
257 
258  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
259  PatternRewriter &rewriter) const override {
260  OpBuilder::InsertionGuard g(rewriter);
261  auto loc = tileLoadOp.getLoc();
262  auto tileType = tileLoadOp.getVectorType();
263  auto tileElementType = tileType.getElementType();
264 
265  auto maskOp = tileLoadOp.getMask();
266  if (!maskOp)
267  return rewriter.notifyMatchFailure(
268  tileLoadOp, "op has no mask, needs unmasked pattern");
269 
270  auto padOp = tileLoadOp.getPadding();
271  assert(padOp && "expected padding when masking!");
272 
273  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
274  if (!createMaskOp)
275  return rewriter.notifyMatchFailure(
276  tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
277  "currently supported");
278 
279  auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
280  if (constPadOp &&
281  constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
282  return rewriter.notifyMatchFailure(
283  tileLoadOp, "op has constant zero pad, needs zero pad pattern");
284 
285  auto numRows = createMaskOp.getOperands()[0];
286  auto numCols = createMaskOp.getOperands()[1];
287 
288  auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
289  loc, rewriter.getI32Type(), numCols);
290 
291  auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
292 
293  // Create a loop that loads each ZA tile slice from memory.
294  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
295  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
296  loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
297  auto vscale =
298  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
299  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
300  auto numTileSlices =
301  rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
302  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
303  step, ValueRange{initTile});
304 
305  rewriter.setInsertionPointToStart(forOp.getBody());
306 
307  auto tileSliceIndex = forOp.getInductionVar();
308  auto currentTile = forOp.getRegionIterArg(0);
309 
310  // Combine masks.
311  auto rowIsActive = rewriter.create<arith::CmpIOp>(
312  loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
313  auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
314  loc, rewriter.getI32Type(), rowIsActive);
315  auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
316  auto maskIndex =
317  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
318  auto predicateType =
319  VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
320  auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
321  loc, predicateType, maskIndex.getResult());
322 
323  auto memrefIndices = getMemrefIndices(
324  tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
325  tileSliceIndex, numTileSlices, loc, rewriter);
326 
327  // Splat pad into 1-D vector matching type of tile slice.
328  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
329  auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
330 
331  auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
332  loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
333  /*passthru=*/pad1DOp);
334 
335  // Create 'arm_sme.insert_tile_slice' to insert slice into tile.
336  auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>(
337  loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
338  tileLoadOp.getLayout());
339  rewriter.create<scf::YieldOp>(loc, insertSlice.getResult());
340 
341  rewriter.setInsertionPointAfter(forOp);
342 
343  // Replace 'arm_sme.tile_load' with the result.
344  rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
345 
346  return success();
347  }
348 };
349 
350 /// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
351 /// slice using `arm_sme.store_tile_slice`.
352 ///
353 /// BEFORE:
354 /// ```mlir
355 /// arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
356 /// : memref<?x?xi32>, vector<[4]x[4]xi32
357 /// ```
358 ///
359 /// AFTER:
360 /// ```mlir
361 /// %vscale = vector.vscale
362 /// %c0 = arith.constant 0 : index
363 /// %c1 = arith.constant 1 : index
364 /// %min_svl_s = arith.constant 4 : index
365 /// %svl_s = arith.muli %min_svl_s, %vscale : index
366 /// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
367 /// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
368 /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
369 /// }
370 /// ```
371 struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
373 
374  LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
375  PatternRewriter &rewriter) const override {
376  // Create a loop that stores each active ZA tile slice from memory.
377  return createLoadStoreForOverTileSlices(
378  rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
379  tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
380  tileStoreOp.getMask(),
381  [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
382  rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
383  tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
384  predicate, tileStoreOp.getBase(), memrefIndices,
385  tileStoreOp.getLayout());
386  });
387  }
388 };
389 
390 } // namespace
391 
393  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
394  TileStoreOpConversion>(patterns.getContext());
395 }
396 
397 namespace {
398 
399 struct ConvertArmSMEToSCFPass
400  : public impl::ConvertArmSMEToSCFBase<ConvertArmSMEToSCFPass> {
401  void runOnOperation() override {
402  RewritePatternSet patterns(&getContext());
403  ConversionTarget target(getContext());
405  target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
406  arith::ArithDialect, scf::SCFDialect>();
407  target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
408  if (failed(applyPartialConversion(getOperation(), target,
409  std::move(patterns))))
410  signalPassFailure();
411  }
412 };
413 
414 } // namespace
415 
416 std::unique_ptr<Pass> mlir::createConvertArmSMEToSCFPass() {
417  return std::make_unique<ConvertArmSMEToSCFPass>();
418 }
static MLIRContext * getContext(OpFoldResult val)
IntegerType getI64Type()
Definition: Builders.cpp:97
IntegerType getI32Type()
Definition: Builders.cpp:95
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:343
IntegerType getI1Type()
Definition: Builders.cpp:85
IndexType getIndexType()
Definition: Builders.cpp:83
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:353
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:436
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:417
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:314
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:339
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:18
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToSCFPass()
Create a pass to convert a subset of ArmSME ops to SCF.
void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the ArmSME dialect to SCF.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358