//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");

/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example, for an [8]x[8] vector with four [4]x[4]
/// sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {
  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
  int row{0};
  int col{0};
  // The SME tile type.
  VectorType type;
};

/// Adds a constant elementwise scalable offset to `indices` (`indices` and
/// `scalableOffsets` must be of equal length). For example, in the 2D case
/// this would return:
///   { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {
  auto vscale = vector::VectorScaleOp::create(builder, loc);
  return llvm::map_to_vector(
      llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
        auto [index, base] = pair;
        auto offset = arith::MulIOp::create(
            builder, loc, arith::ConstantIndexOp::create(builder, loc, base),
            vscale);
        return arith::AddIOp::create(builder, loc, index, offset);
      });
}
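
// For example (illustrative, not from the pass itself): with hypothetical 2D
// indices (%i, %j) and scalableOffsets = {4, -4}, the emitted IR is roughly:
//
//   %vscale = vector.vscale
//   %off0   = arith.muli %c4, %vscale : index
//   %new0   = arith.addi %i, %off0 : index
//   %off1   = arith.muli %c-4, %vscale : index
//   %new1   = arith.addi %j, %off1 : index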

/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
/// indices for one of the SME sub-tiles it will decompose into.
///
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
/// indices for each tile would need to be adjusted as follows:
///
/// initial indices = [a,b], initial size = 8x8, target size = 4x4
/// ┌─────────────┬─────────────┐
/// │[a,b]        │[a,b+4]      │
/// │             │             │
/// ├─────────────┼─────────────┤
/// │[a+4,b]      │[a+4,b+4]    │
/// │             │             │
/// └─────────────┴─────────────┘
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {
  return addConstantScalableOffset(builder, loc, indices,
                                   {smeTile.row, smeTile.col});
}

/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is just no mask, or vector.create_mask.
/// TODO: Add support for vector.constant_mask once required for SME.
bool isSupportedMaskOp(Value mask) {
  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
}

/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {
  assert(isSupportedMaskOp(mask));
  if (!mask)
    return Value{};
  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
  // The operands of `vector.create_mask` (from a 2D perspective) are the
  // coordinates where the mask ends. So we subtract the sub-tile's start
  // position from the mask operands to get the parameters for this sub-tile.
  auto smeTileMaskDims = addConstantScalableOffset(
      builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
  auto smeTileCreateMask = vector::CreateMaskOp::create(
      builder, loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
  return smeTileCreateMask.getResult();
}
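
// Worked example (illustrative, hypothetical value names): for a mask defined
// by
//   %mask = vector.create_mask %a, %b : vector<[8]x[8]xi1>
// the [4]x[4] sub-tile at (row, col) = (4, 0) receives the mask
//   %rowEnd = arith.addi %a, %offset : index  // %offset = -4 * vscale
//   %submask = vector.create_mask %rowEnd, %colEnd : vector<[4]x[4]xi1>
// where %colEnd is %b plus a zero offset.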

/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, if decomposing an [8]x[8] into
/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
/// (4, 4).
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
  return llvm::map_range(
      StaticTileOffsetRange(
          type.getShape(),
          {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
           std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
      [=](auto indices) {
        int row = int(indices[0]);
        int col = int(indices[1]);
        if (transposeIndices)
          std::swap(row, col);
        return SMESubTile{row, col, smeTileType};
      });
}

/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
/// `type`.
int getNumberOfSMETilesForVectorType(VectorType type) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  int64_t vectorRows = type.getDimSize(0);
  int64_t vectorCols = type.getDimSize(1);
  auto elementType = type.getElementType();
  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
}
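
// Worked example: an f32 SME tile is vector<[4]x[4]xf32> (the minimum SVL is
// 128 bits, so a tile slice holds at least 128 / 32 = 4 f32 elements). A
// vector<[8]x[8]xf32> therefore decomposes into (8 * 8) / (4 * 4) = 4 tiles.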

/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
struct LegalizeArithConstantOpsByDecomposition
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    if (!vectorType || !denseAttr || !denseAttr.isSplat())
      return failure();

    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(constantOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
    auto tileSplat = arith::ConstantOp::create(
        rewriter, constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
    SmallVector<Value> repl(tileCount, tileSplat);
    rewriter.replaceOpWithMultiple(constantOp, {repl});

    return success();
  }
};
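
// Illustrative decomposition (f32, so the SME tile type is
// vector<[4]x[4]xf32>):
//
//   %splat = arith.constant dense<1.0> : vector<[8]x[8]xf32>
//
// becomes a single tile-sized splat that is reused for all four result tiles:
//
//   %tile = arith.constant dense<1.0> : vector<[4]x[4]xf32>
//   // %splat is replaced with the values (%tile, %tile, %tile, %tile).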

/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::OuterProductOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp,
                  OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = outerProductOp.getResultVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    Value mask;
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      auto maskOp = outerProductOp.getMaskingOp();
      mask = maskOp.getMask();
      rootOp = maskOp;
      rewriter.setInsertionPoint(rootOp);
    }

    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureUnsupportedMaskOp);

    ValueRange accSMETiles = adaptor.getAcc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);

    SmallVector<Value> resultSMETiles;
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {

      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto lhs = vector::ScalableExtractOp::create(
          rewriter, loc, sliceType, outerProductOp.getLhs(), smeTile.row);
      auto rhs = vector::ScalableExtractOp::create(
          rewriter, loc, sliceType, outerProductOp.getRhs(), smeTile.col);
      auto smeOuterProduct = vector::OuterProductOp::create(
          rewriter, loc, smeTileType, lhs, rhs,
          !accSMETiles.empty() ? accSMETiles[index] : Value{},
          outerProductOp.getKind());

      auto maskedOuterProduct =
          vector::maskOperation(rewriter, smeOuterProduct, smeMask);
      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
    }

    rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
    return success();
  }
};
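
// Illustrative decomposition (f32, four [4]x[4] result tiles):
//
//   %op = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>
//
// becomes one outer product per sub-tile; e.g. for the sub-tile at (4, 0):
//
//   %lhsHi = vector.scalable.extract %lhs[4]
//              : vector<[4]xf32> from vector<[8]xf32>
//   %rhsLo = vector.scalable.extract %rhs[0]
//              : vector<[4]xf32> from vector<[8]xf32>
//   %tile  = vector.outerproduct %lhsHi, %rhsLo
//              : vector<[4]xf32>, vector<[4]xf32>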

// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
// get the help of the type conversion), but doing so results in the type
// conversion adding target materializations in the `vector.mask` region
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::MaskOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
            maskOp.getMaskableOp())) {
      LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
                                                           getContext());
      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
          outerProductOp, rewriter);
    }
    return failure();
  }
};

/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferReadOpsByDecomposition
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = readOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = readOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = readOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = readOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());

    SmallVector<Value> resultSMETiles;
    for (SMESubTile smeTile :
         decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeRead = vector::TransferReadOp::create(
          rewriter, loc, smeTileType, readOp.getBase(),
          getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
          readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
          readOp.getInBoundsAttr());
      resultSMETiles.push_back(smeRead);
    }

    rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
    return success();
  }
};
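
// Illustrative decomposition (hypothetical indices %y, %x):
//
//   %read = vector.transfer_read %src[%y, %x], %pad
//             : memref<?x?xf32>, vector<[8]x[8]xf32>
//
// becomes four [4]x[4] reads with per-sub-tile offsets, e.g. for (0, 4):
//
//   %tile01 = vector.transfer_read %src[%y, %x_plus_4_vscale], %pad
//               : memref<?x?xf32>, vector<[4]x[4]xf32>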

/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferWriteOpsByDecomposition
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = writeOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto inputSMETiles = adaptor.getValueToStore();

    Value destTensorOrMemref = writeOp.getBase();
    for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
             rewriter, vectorType, smeTileType, transposed))) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeWrite = vector::TransferWriteOp::create(
          rewriter, loc, inputSMETiles[index], destTensorOrMemref,
          getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
          writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};
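
// Illustrative decomposition (memref case, hypothetical indices %y, %x):
//
//   vector.transfer_write %vec, %dest[%y, %x]
//     : vector<[8]x[8]xf32>, memref<?x?xf32>
//
// becomes one write per sub-tile, e.g. for the sub-tile at (0, 4):
//
//   vector.transfer_write %tile01, %dest[%y, %x_plus_4_vscale]
//     : vector<[4]x[4]xf32>, memref<?x?xf32>
//
// With tensor semantics the writes are instead chained through the result
// tensor, as in the loop above.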

/// Legalize a multi-tile transfer_write as a single store loop. This is done
/// as part of type decomposition, since at this level we know each tile write
/// is disjoint, but that information is lost after decomposition (without
/// analysis to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
///   : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
///   %upper_slice_mask = vector.extract %mask[%slice_idx]            ─┐
///                     : vector<[8]xi1> from vector<[16]x[8]xi1>      |
///   %upper_slice = vector.extract %upper_tile[%slice_idx]            |- Store upper tile
///                : vector<[8]xi16> from vector<[8]x[8]xi16>          |
///   vector.transfer_write %upper_slice,                              |
///     %dest[%slice_idx + %y, %x], %upper_slice_mask                  |
///                   : vector<[8]xi16>, memref<?x?xi16>              ─┘
///   %lower_slice_idx = %slice_idx + %c8_vscale                      ─┐
///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]       |
///                     : vector<[8]xi1> from vector<[16]x[8]xi1>      |
///   %lower_slice = vector.extract %lower_tile[%slice_idx]            |- Store lower tile
///                : vector<[8]xi16> from vector<[8]x[8]xi16>          |
///   vector.transfer_write %lower_slice,                              |
///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask            |
///                   : vector<[8]xi16>, memref<?x?xi16>              ─┘
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (writeOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          writeOp, "TODO: tensor semantics are unsupported");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    bool transposed = !permutationMap.isIdentity();
    if (transposed)
      return rewriter.notifyMatchFailure(writeOp,
                                         "TODO: transpose unsupported");

    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // Note: We also disallow masks where any dimension is > 16 because that
    // prevents the masking from being lowered to use arm_sve.psel.
    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
                                              vectorType.getDimSize(1) > 16)))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    // Get SME tile and slice types.
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto minTileSlices = smeTileType.getDimSize(0);
    VectorType sliceMaskType =
        VectorType::get(minTileSlices, rewriter.getI1Type(), true);

    // Create loop over all tile slices.
    auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto upperBound = createVscaleMultiple(minTileSlices);
    auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
    auto storeLoop =
        scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
    rewriter.setInsertionPointToStart(storeLoop.getBody());

    // For each sub-tile of the multi-tile `vectorType`.
    auto inputSMETiles = adaptor.getValueToStore();
    auto tileSliceIndex = storeLoop.getInductionVar();
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
      // The coordinates of the tile within `vectorType`.
      auto tileRow = createVscaleMultiple(smeTile.row);
      auto tileCol = createVscaleMultiple(smeTile.col);

      // The current slice of `vectorType` we are processing.
      auto sliceIndex =
          arith::AddIOp::create(rewriter, loc, tileRow, tileSliceIndex);

      // Where in the destination memref the current slice will be stored.
      auto storeRow = arith::AddIOp::create(rewriter, loc, sliceIndex,
                                            writeOp.getIndices()[0]);
      auto storeCol = arith::AddIOp::create(rewriter, loc, tileCol,
                                            writeOp.getIndices()[1]);

      // Extract the mask for the current slice.
      Value sliceMask = nullptr;
      if (mask) {
        sliceMask = vector::ExtractOp::create(rewriter, loc, mask,
                                              OpFoldResult(sliceIndex));
        if (sliceMaskType != sliceMask.getType())
          sliceMask = vector::ScalableExtractOp::create(
              rewriter, loc, sliceMaskType, sliceMask, smeTile.col);
      }

      // Extract and store the current slice.
      Value tile = inputSMETiles[index];
      auto slice =
          vector::ExtractOp::create(rewriter, loc, tile, tileSliceIndex);
      vector::TransferWriteOp::create(
          rewriter, loc, slice, writeOp.getBase(),
          ValueRange{storeRow, storeCol},
          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
          sliceMask,
          rewriter.getBoolArrayAttr(
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
    }

    rewriter.eraseOp(writeOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
/// %subMask = vector.extract %mask[2]
///          : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
/// ```
///
/// AFTER:
/// ```mlir
/// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
/// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
/// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
/// ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto loc = extractOp.getLoc();
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(
          extractOp, "extract not from vector.create_mask op");

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return rewriter.notifyMatchFailure(extractOp,
                                         "extracted type is not a vector type");

    auto numScalable = extractedMaskType.getNumScalableDims();
    if (numScalable != 2)
      return rewriter.notifyMatchFailure(
          extractOp, "expected extracted type to be an SME-like mask");

    // TODO: Support multiple extraction indices.
    if (extractOp.getStaticPosition().size() != 1)
      return rewriter.notifyMatchFailure(
          extractOp, "only a single extraction index is supported");

    auto frontMaskDim = createMaskOp.getOperand(0);
    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
      return rewriter.notifyMatchFailure(
          extractOp,
          "constant vector.create_mask dims should be folded elsewhere");

    auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto extractionIndex = getValueOrCreateConstantIndexOp(
        rewriter, loc, extractOp.getMixedPosition()[0]);
    auto extractionInTrueRegion = arith::CmpIOp::create(
        rewriter, loc, rewriter.getI1Type(), arith::CmpIPredicate::slt,
        extractionIndex, frontMaskDim);
    auto newMaskFrontDim =
        arith::SelectOp::create(rewriter, loc, extractionInTrueRegion,
                                createMaskOp.getOperand(1), zero);

    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
        extractOp, extractedMaskType,
        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
    return success();
  }
};

/// A vector type where no fixed dimension comes after a scalable dimension.
bool isLegalVectorType(VectorType vType) {
  bool seenFixedDim = false;
  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
    seenFixedDim |= !scalableFlag;
    if (seenFixedDim && scalableFlag)
      return false;
  }
  return true;
}
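
// For example:
//   vector<[4]x[4]xf32> -> legal   (no fixed dims)
//   vector<4x[4]xf32>   -> legal   (fixed dim before the scalable dim)
//   vector<[4]x4xf32>   -> illegal (fixed dim after the scalable dim)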

/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
/// 'Illegal' here means a leading scalable dimension and a fixed trailing
/// dimension, which has no valid lowering.
///
/// The memref.transpose is a metadata-only transpose that produces a strided
/// memref, which eventually becomes a loop reading individual elements.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %illegalRead = vector.transfer_read %memref[%a, %b]
///                  : memref<?x?xf32>, vector<[8]x4xf32>
/// %legalType = vector.transpose %illegalRead, [1, 0]
///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
///                  : memref<?x?xf32> to memref<?x?xf32>
/// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
///                  : memref<?x?xf32> to memref<?x?xf32>
/// %legalType = vector.transfer_read %transpose[%c0, %c0]
///                  : memref<?x?xf32>, vector<4x[8]xf32>
/// ```
struct LiftIllegalVectorTransposeToMemory
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  static Value getExtensionSource(Operation *op) {
    if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
      return op->getOperand(0);
    return {};
  }

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(transposeOp,
                                         kMatchFailureNotIllegalToLegal);

    // Look through extend for transfer_read.
    Value maybeRead = transposeOp.getVector();
    auto *transposeSourceOp = maybeRead.getDefiningOp();
    Operation *extendOp = nullptr;
    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
      maybeRead = extendSource;
      extendOp = transposeSourceOp;
    }

    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
    if (!illegalRead)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "expected source to be (possibly extended) transfer_read");

    if (!illegalRead.getPermutationMap().isIdentity())
      return rewriter.notifyMatchFailure(
          illegalRead, "expected read to have identity permutation map");

    auto loc = transposeOp.getLoc();
    auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);

    // Create a subview that matches the size of the illegal read vector type.
    auto readType = illegalRead.getVectorType();
    auto readSizes = llvm::map_to_vector(
        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
        [&](auto dim) -> Value {
          auto [size, isScalable] = dim;
          auto dimSize = arith::ConstantIndexOp::create(rewriter, loc, size);
          if (!isScalable)
            return dimSize;
          auto vscale = vector::VectorScaleOp::create(rewriter, loc);
          return arith::MulIOp::create(rewriter, loc, vscale, dimSize);
        });
    SmallVector<Value> strides(readType.getRank(), Value(one));
    auto readSubview =
        memref::SubViewOp::create(rewriter, loc, illegalRead.getBase(),
                                  illegalRead.getIndices(), readSizes, strides);

    // Apply the transpose to all values/attributes of the transfer_read:
    // - The mask
    Value mask = illegalRead.getMask();
    if (mask) {
      // Note: The transpose for the mask should fold into the
      // vector.create_mask/constant_mask op, which will then become legal.
      mask = vector::TransposeOp::create(rewriter, loc, mask,
                                         transposeOp.getPermutation());
    }
    // - The source memref
    auto transposeMap = AffineMap::getPermutationMap(
        transposeOp.getPermutation(), getContext());
    auto transposedSubview = memref::TransposeOp::create(
        rewriter, loc, readSubview, AffineMapAttr::get(transposeMap));
    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
    // - The `in_bounds` attribute
    if (inBoundsAttr) {
      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
                                            inBoundsAttr.end());
      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
    }

    VectorType legalReadType = resultType.clone(readType.getElementType());
    // Note: The indices are all zero as the subview is already offset.
    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
    auto legalRead = vector::TransferReadOp::create(
        rewriter, loc, legalReadType, transposedSubview, readIndices,
        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
        inBoundsAttr);

    // Replace the transpose with the new read, extending the result if
    // necessary.
    rewriter.replaceOp(transposeOp, [&]() -> Operation * {
      if (extendOp)
        return rewriter.create(loc, extendOp->getName().getIdentifier(),
                               Value(legalRead), resultType);
      return legalRead;
    }());

    return success();
  }
};

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
/// use the ZA state. This is a workaround to support these transposes when ZA
/// is available.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %transpose = vector.transpose %vec, [1, 0]
///    : vector<2x[4]xf32> to vector<[4]x2xf32>
/// vector.transfer_write %transpose, %dest[%y, %x]
///    : vector<[4]x2xf32>, memref<?x?xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
/// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
/// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
/// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
/// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
/// vector.transfer_write %4, %dest[%y, %x], %mask
///   {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
///   : vector<[4]x[4]xf32>, memref<?x?xf32>
/// ```
///
/// Values larger than a single tile are supported via decomposition.
struct LowerIllegalTransposeStoreViaZA
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!isSupportedMaskOp(writeOp.getMask()))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!transposeOp)
      return failure();

    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();

    if (resultType.getRank() != 2)
      return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");

    if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(
          transposeOp, "not illegal/unsupported SVE transpose");

    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
    VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);

    if (sourceType.getDimSize(0) <= 1 ||
        sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
      return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto transposeMap = AffineMapAttr::get(
        AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));

    // Note: We need to use `get_tile` as there's no vector-level `undef`.
    Value undefTile = arm_sme::GetTileOp::create(rewriter, loc, smeTileType);
    Value destTensorOrMemref = writeOp.getBase();
    auto numSlicesPerTile =
        std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
    auto numSlices =
        arith::ConstantIndexOp::create(rewriter, loc, numSlicesPerTile);
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
      // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
      // of slices from the source type into the SME tile. Without checking
      // vscale (and emitting multiple implementations) we can't make use of
      // the rows of the tile after 1*vscale rows.
      Value tile = undefTile;
      for (int d = 0; d < numSlicesPerTile; ++d) {
        Value vector =
            vector::ExtractOp::create(rewriter, loc, transposeOp.getVector(),
                                      rewriter.getIndexAttr(d + smeTile.row));
        if (vector.getType() != smeSliceType) {
          vector = vector::ScalableExtractOp::create(
              rewriter, loc, smeSliceType, vector, smeTile.col);
        }
        tile = vector::InsertOp::create(rewriter, loc, vector, tile, d);
      }

      // 2. Transpose the tile position.
      auto transposedRow = createVscaleMultiple(smeTile.col);
      auto transposedCol =
          arith::ConstantIndexOp::create(rewriter, loc, smeTile.row);

      // 3. Compute mask for tile store.
      Value maskRows;
      Value maskCols;
      if (auto mask = writeOp.getMask()) {
        auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
        maskRows = arith::SubIOp::create(
            rewriter, loc, createMask.getOperand(0), transposedRow);
        maskCols = arith::SubIOp::create(
            rewriter, loc, createMask.getOperand(1), transposedCol);
        maskCols = index::MinSOp::create(rewriter, loc, maskCols, numSlices);
      } else {
        maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
        maskCols = numSlices;
      }
      auto subMask = vector::CreateMaskOp::create(
          rewriter, loc, smeTileType.clone(rewriter.getI1Type()),
          ValueRange{maskRows, maskCols});

      // 4. Emit a transposed tile write.
      auto writeIndices = writeOp.getIndices();
      Value destRow =
          arith::AddIOp::create(rewriter, loc, transposedRow, writeIndices[0]);
      Value destCol =
          arith::AddIOp::create(rewriter, loc, transposedCol, writeIndices[1]);
      auto smeWrite = vector::TransferWriteOp::create(
          rewriter, loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
          transposeMap, subMask, writeOp.getInBounds());

      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

/// Lower `vector.transfer_read` of a scalable column to `scf::for`
///
/// Lowers a "read" of a scalable column from a MemRef, for which there is no
/// hardware operation we could use, to a loop over the rows that loads one
/// element at a time.
///
/// BEFORE:
/// ```
/// %res = vector.transfer_read %mem[%a, %b] (...)
///   : memref<?x?xf32>, vector<[4]x1xf32>
/// ```
///
/// AFTER:
/// ```
/// %cst = arith.constant (...) : vector<[4]xf32>
/// %vscale = vector.vscale
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// %scf = scf.for %lb = %c0 to %c4_vscale step %c1 iter_args(%arg4 = %cst)
///   -> (vector<[4]xf32>) {
///
///   %load = memref.load %mem[%arg3 + %a, %b] : memref<?x?xf32>
///   %vec = vector.insert %load, %cst [%arg3] : f32 into vector<[4]xf32>
///   scf.yield %vec : vector<[4]xf32>
/// }
/// %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32>
/// ```
///
/// TODO: This transformation isn't specific to SME - move it to the SVE
/// dialect.
/// TODO: Check the in_bounds attribute and generate vector.maskedload if
/// required.
struct LowerColumnTransferReadToLoops
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // NOTE: This is a fairly low-level transformation, so we shouldn't be
    // adding support for Tensors without good rationale.
    if (readOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          readOp, "Tensor semantics are unsupported (either bufferize or "
                  "extend this pattern)");

    auto resType = readOp.getVectorType();

    if (resType.getRank() != 2)
      return rewriter.notifyMatchFailure(readOp,
                                         "Only 2D vectors are supported!");

    if (resType.getShape()[1] != 1)
      return rewriter.notifyMatchFailure(
          readOp, "The trailing output dim is != 1 (not supported ATM)");

    if (!resType.getScalableDims()[0] || resType.getScalableDims()[1])
      return rewriter.notifyMatchFailure(
          readOp, "Expected the leading dim to be scalable and the trailing "
                  "dim to be fixed.");

    // Create new result type - similar to the original vector with the
    // trailing unit dim collapsed.
    int64_t numRows = resType.getShape()[0];
    VectorType newResType = VectorType::get(numRows, resType.getElementType(),
                                            /*scalableDims=*/{true});

    // Create a loop over all rows and load one element at a time.
    auto loc = readOp.getLoc();
    auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);
    auto upperBound = createVscaleMultiple(numRows);
    auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
    Value init = arith::ConstantOp::create(
        rewriter, loc, newResType, DenseElementsAttr::get(newResType, 0.0f));

    scf::ForOp loadLoop;
    {
      OpBuilder::InsertionGuard g(rewriter);
      loadLoop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound,
                                    step, ValueRange{init});
      rewriter.setInsertionPointToStart(loadLoop.getBody());

      auto tileSliceIndex = loadLoop.getInductionVar();

      auto idx0 = arith::AddIOp::create(rewriter, loc, tileSliceIndex,
                                        readOp.getIndices()[0]);
      auto idx1 = readOp.getIndices()[1];

      Value scalar = memref::LoadOp::create(rewriter, loc, readOp.getBase(),
                                            SmallVector<Value>({idx0, idx1}));

      Operation *updateInit = vector::InsertOp::create(
          rewriter, loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex);

      scf::YieldOp::create(rewriter, loc, updateInit->getResult(0));
    }

    // The read operation has been "legalized", but since the original result
    // type was a 2D vector, we need to cast before returning the result. This
    // ShapeCast should cancel-out with some other ShapeCast (i.e. it's a
    // no-op).
    auto sc = vector::ShapeCastOp::create(
        rewriter, loc, readOp.getResult().getType(), loadLoop.getResult(0));

    rewriter.replaceOp(readOp, sc);

    return success();
  }
};

struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  void runOnOperation() override {
    auto *context = &getContext();
    TypeConverter converter;
    RewritePatternSet patterns(context);
    converter.addConversion([](Type type) { return type; });
    converter.addConversion(
        [](VectorType vectorType,
           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
          if (!isMultipleOfSMETileVectorType(vectorType))
            return std::nullopt;
          auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
          auto smeTileType =
              getSMETileTypeForElement(vectorType.getElementType());
          types = SmallVector<Type>(smeTileCount, smeTileType);
          return success();
        });

    // Apply preprocessing patterns.
    RewritePatternSet rewritePatterns(context);
    rewritePatterns
        .add<FoldExtractFromVectorOfSMELikeCreateMasks,
             LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory,
             LowerIllegalTransposeStoreViaZA>(context);
    if (failed(
            applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
      return signalPassFailure();

    // Note: These two patterns are added with a high benefit to ensure:
    //  - Masked outer products are handled before unmasked ones
    //  - Multi-tile writes are lowered as a store loop (if possible)
    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
                                                            /*benefit=*/1024);
    patterns.add<LegalizeArithConstantOpsByDecomposition,
                 LegalizeVectorOuterProductOpsByDecomposition,
                 LegalizeTransferReadOpsByDecomposition,
                 LegalizeTransferWriteOpsByDecomposition>(converter, context);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    populateSCFStructuralTypeConversions(converter, patterns);

    ConversionTarget target(getContext());
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return converter.isLegal(op); });
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
  return std::make_unique<VectorLegalizationPass>();
}
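
// Usage sketch (assuming the standard upstream registration, where the pass
// argument matches DEBUG_TYPE):
//
//   mlir-opt -arm-sme-vector-legalization input.mlir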