MLIR  20.0.0git
VectorLegalization.cpp
Go to the documentation of this file.
1 //===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass legalizes vector operations so they can be lowered to ArmSME.
10 //
11 // Note: In the context of this pass 'tile' always refers to an SME tile.
12 //
13 //===----------------------------------------------------------------------===//
14 
29 
30 #define DEBUG_TYPE "arm-sme-vector-legalization"
31 
32 namespace mlir::arm_sme {
33 #define GEN_PASS_DEF_VECTORLEGALIZATION
34 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
35 } // namespace mlir::arm_sme
36 
37 using namespace mlir;
38 using namespace mlir::arm_sme;
39 
40 namespace {
41 
42 //===----------------------------------------------------------------------===//
43 // Decomposition of vector operations larger than an SME tile
44 //===----------------------------------------------------------------------===//
45 
46 // Common match failure reasons.
47 static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
48  "op vector size is not multiple of SME tiles");
49 static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
50  "op mask is unsupported for legalization/decomposition");
51 static constexpr StringLiteral
52  kMatchFailureNonPermutationMap("op affine map is not a permutation");
53 static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
54  "expected transpose from illegal type to legal type");
55 
56 /// An SMESubTile represents a single SME-sized sub-tile from decomposing a
57 /// larger vector type. The (`row`, `col`) are the position of the tile in the
58 /// original vector type. For example for an [8]x[8] tile with four [4]x[4]
59 /// sub-tiles, we would have:
60 ///
61 /// 8 x vscale
62 /// ┌─────────────┬─────────────┐
63 /// │(0,0) │(0,4) │
64 /// │ │ │
65 /// ├─────────────┼─────────────┤ 8 x vscale
66 /// │(4,0) │(4,4) │
67 /// │ │ │
68 /// └─────────────┴─────────────┘
69 struct SMESubTile {
70  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
71  int row{0};
72  int col{0};
73  // The SME tile type.
74  VectorType type;
75 };
76 
77 /// Adds a constant elementwise scalable offset to `indices` (which are of equal
78 /// length). For example, in the 2D case this would return:
79 /// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
80 SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
81  Location loc,
82  ValueRange indices,
83  ArrayRef<int> scalableOffsets) {
84  auto vscale = builder.create<vector::VectorScaleOp>(loc);
85  return llvm::map_to_vector(
86  llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
87  auto [index, base] = pair;
88  auto offset = builder.create<arith::MulIOp>(
89  loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
90  return builder.create<arith::AddIOp>(loc, index, offset);
91  });
92 }
93 
94 /// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
95 /// indices for one of the SME sub-tiles it will decompose into.
96 ///
97 /// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
98 /// indices for each tile would need to be adjusted as follows:
99 ///
100 /// initial indices = [a,b], initial size = 8x8, target size = 4x4
101 /// ┌─────────────┬─────────────┐
102 /// │[a,b] │[a,b+4] │
103 /// │ │ │
104 /// ├─────────────┼─────────────┤
105 /// │[a+4,b] │[a+4,b+4] │
106 /// │ │ │
107 /// └─────────────┴─────────────┘
108 SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
109  ValueRange indices,
110  SMESubTile smeTile) {
111  return addConstantScalableOffset(builder, loc, indices,
112  {smeTile.row, smeTile.col});
113 }
114 
115 /// Returns true if `mask` is generated by an operation that can be decomposed
116 /// for SME. Currently, that is just no mask, or vector.create_mask.
117 /// TODO: Add support for vector.constant_mask once required for SME.
118 bool isSupportedMaskOp(Value mask) {
119  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
120 }
121 
122 /// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
123 Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
124  SMESubTile smeTile) {
125  assert(isSupportedMaskOp(mask));
126  if (!mask)
127  return Value{};
128  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
129  // The operands of `vector.create_mask` (from a 2D perspective) are the
130  // coordinates where the mask ends. So we subtract where this tile starts,
131  // from the mask operands to get the parameters for this sub-tile.
132  auto smeTileMaskDims = addConstantScalableOffset(
133  builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
134  auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
135  loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
136  return smeTileCreateMask.getResult();
137 }
138 
139 /// Constructs an iterator that returns each SME tile (with coordinates)
140 /// contained within a VectorType. For example, if decomposing an [8]x[8] into
141 /// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
142 /// (4, 4).
143 auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
144  VectorType smeTileType,
145  bool transposeIndices = false) {
146  return llvm::map_range(
148  type.getShape(),
149  {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
150  std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
151  [=](auto indices) {
152  int row = int(indices[0]);
153  int col = int(indices[1]);
154  if (transposeIndices)
155  std::swap(row, col);
156  return SMESubTile{row, col, smeTileType};
157  });
158 }
159 
160 /// Returns the number of SME tiles that fit into the (2D-scalable) vector type
161 /// `type`.
162 int getNumberOfSMETilesForVectorType(VectorType type) {
163  assert(isMultipleOfSMETileVectorType(type) &&
164  "`type` not multiple of SME tiles");
165  int64_t vectorRows = type.getDimSize(0);
166  int64_t vectorCols = type.getDimSize(1);
167  auto elementType = type.getElementType();
168  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
169  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
170 }
171 
172 /// Legalize `arith.constant dense<value>` splat operations to fit within SME
173 /// tiles by decomposing them into tile-sized operations.
174 struct LegalizeArithConstantOpsByDecomposition
175  : public OneToNOpConversionPattern<arith::ConstantOp> {
177 
178  LogicalResult
179  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
180  OneToNPatternRewriter &rewriter) const override {
181  auto vectorType = dyn_cast<VectorType>(constantOp.getType());
182  auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
183  if (!vectorType || !denseAttr || !denseAttr.isSplat())
184  return failure();
185 
186  if (!isMultipleOfSMETileVectorType(vectorType))
187  return rewriter.notifyMatchFailure(constantOp,
188  kMatchFailureNotSMETileTypeMultiple);
189 
190  auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
191  auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
192  auto tileSplat = rewriter.create<arith::ConstantOp>(
193  constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
194  rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
195  adaptor.getResultMapping());
196 
197  return success();
198  }
199 };
200 
201 /// Legalize `vector.outerproduct` operations to fit within SME tiles by
202 /// decomposing them into tile-sized operations.
203 struct LegalizeVectorOuterProductOpsByDecomposition
204  : public OneToNOpConversionPattern<vector::OuterProductOp> {
206 
207  LogicalResult
208  matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
209  OneToNPatternRewriter &rewriter) const override {
210  auto vectorType = outerProductOp.getResultVectorType();
211  if (!isMultipleOfSMETileVectorType(vectorType))
212  return rewriter.notifyMatchFailure(outerProductOp,
213  kMatchFailureNotSMETileTypeMultiple);
214 
215  Value mask;
216  Operation *rootOp = outerProductOp;
217  auto loc = outerProductOp.getLoc();
218  if (outerProductOp.isMasked()) {
219  auto maskOp = outerProductOp.getMaskingOp();
220  mask = maskOp.getMask();
221  rootOp = maskOp;
222  }
223 
224  if (!isSupportedMaskOp(mask))
225  return rewriter.notifyMatchFailure(outerProductOp,
226  kMatchFailureUnsupportedMaskOp);
227 
228  ValueRange accSMETiles = adaptor.getAcc();
229  auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
230  VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);
231 
232  SmallVector<Value> resultSMETiles;
233  for (auto [index, smeTile] : llvm::enumerate(
234  decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
235 
236  auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
237  auto lhs = rewriter.create<vector::ScalableExtractOp>(
238  loc, sliceType, outerProductOp.getLhs(), smeTile.row);
239  auto rhs = rewriter.create<vector::ScalableExtractOp>(
240  loc, sliceType, outerProductOp.getRhs(), smeTile.col);
241  auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
242  loc, smeTileType, lhs, rhs,
243  !accSMETiles.empty() ? accSMETiles[index] : Value{},
244  outerProductOp.getKind());
245 
246  auto maskedOuterProduct =
247  vector::maskOperation(rewriter, smeOuterProduct, smeMask);
248  resultSMETiles.push_back(maskedOuterProduct->getResult(0));
249  }
250 
251  rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
252  return success();
253  }
254 };
255 
256 // Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
257 // get the help of the type conversion), but doing so results in the type
258 // conversion adding target materializations in the `vector.mask` region
259 // (invalid). This pattern matches on `vector.mask` then calls into the
260 // `vector.outerproduct` pattern to work around this issue.
261 struct LegalizeMaskedVectorOuterProductOpsByDecomposition
262  : public OneToNOpConversionPattern<vector::MaskOp> {
264 
265  LogicalResult
266  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
267  OneToNPatternRewriter &rewriter) const override {
268  if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
269  maskOp.getMaskableOp())) {
270  LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
271  getContext());
272  return static_cast<RewritePattern &>(pattern).matchAndRewrite(
273  outerProductOp, rewriter);
274  }
275  return failure();
276  }
277 };
278 
279 /// Legalize `vector.transfer_read` operations to fit within SME tiles by
280 /// decomposing them into tile-sized operations.
281 struct LegalizeTransferReadOpsByDecomposition
282  : public OneToNOpConversionPattern<vector::TransferReadOp> {
284 
285  LogicalResult
286  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
287  OneToNPatternRewriter &rewriter) const override {
288  auto vectorType = readOp.getVectorType();
289  if (!isMultipleOfSMETileVectorType(vectorType))
290  return rewriter.notifyMatchFailure(readOp,
291  kMatchFailureNotSMETileTypeMultiple);
292 
293  auto mask = readOp.getMask();
294  if (!isSupportedMaskOp(mask))
295  return rewriter.notifyMatchFailure(readOp,
296  kMatchFailureUnsupportedMaskOp);
297 
298  auto permutationMap = readOp.getPermutationMap();
299  if (!permutationMap.isPermutation())
300  return rewriter.notifyMatchFailure(readOp,
301  kMatchFailureNonPermutationMap);
302 
303  // Note: For 2D vector types the only non-identity permutation is a simple
304  // tranpose [1, 0].
305  bool transposed = !permutationMap.isIdentity();
306 
307  auto loc = readOp.getLoc();
308  auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
309 
310  SmallVector<Value> resultSMETiles;
311  for (SMESubTile smeTile :
312  decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
313  auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
314  auto smeRead = rewriter.create<vector::TransferReadOp>(
315  loc, smeTileType, readOp.getSource(),
316  getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
317  readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
318  readOp.getInBoundsAttr());
319  resultSMETiles.push_back(smeRead);
320  }
321 
322  rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
323  return success();
324  }
325 };
326 
327 /// Legalize `vector.transfer_write` operations to fit within SME tiles by
328 /// decomposing them into tile-sized operations.
329 struct LegalizeTransferWriteOpsByDecomposition
330  : public OneToNOpConversionPattern<vector::TransferWriteOp> {
332 
333  LogicalResult
334  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
335  OneToNPatternRewriter &rewriter) const override {
336  auto vectorType = writeOp.getVectorType();
337  if (!isMultipleOfSMETileVectorType(vectorType))
338  return rewriter.notifyMatchFailure(writeOp,
339  kMatchFailureNotSMETileTypeMultiple);
340 
341  auto mask = writeOp.getMask();
342  if (!isSupportedMaskOp(mask))
343  return rewriter.notifyMatchFailure(writeOp,
344  kMatchFailureUnsupportedMaskOp);
345 
346  auto permutationMap = writeOp.getPermutationMap();
347  if (!permutationMap.isPermutation())
348  return rewriter.notifyMatchFailure(writeOp,
349  kMatchFailureNonPermutationMap);
350 
351  // Note: For 2D vector types the only non-identity permutation is a simple
352  // tranpose [1, 0].
353  bool transposed = !permutationMap.isIdentity();
354 
355  auto loc = writeOp.getLoc();
356  auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
357  auto inputSMETiles = adaptor.getVector();
358 
359  Value destTensorOrMemref = writeOp.getSource();
360  for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
361  rewriter, vectorType, smeTileType, transposed))) {
362  auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
363  auto smeWrite = rewriter.create<vector::TransferWriteOp>(
364  loc, inputSMETiles[index], destTensorOrMemref,
365  getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
366  writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
367  if (writeOp.hasPureTensorSemantics())
368  destTensorOrMemref = smeWrite.getResult();
369  }
370 
371  if (writeOp.hasPureTensorSemantics())
372  rewriter.replaceOp(writeOp, destTensorOrMemref);
373  else
374  rewriter.eraseOp(writeOp);
375 
376  return success();
377  }
378 };
379 
380 /// Legalize a multi-tile transfer_write as a single store loop. This is done as
381 /// part of type decomposition as at this level we know each tile write is
382 /// disjoint, but that information is lost after decomposition (without analysis
383 /// to reconstruct it).
384 ///
385 /// Example (pseudo-MLIR):
386 ///
387 /// ```
388 /// vector.transfer_write %vector, %dest[%y, %x], %mask
389 /// : vector<[16]x[8]xi16>, memref<?x?xi16>
390 /// ```
391 /// Is rewritten to:
392 /// ```
393 /// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
394 /// %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
395 /// : vector<[8]xi1> from vector<[16]x[8]xi1> |
396 /// %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
397 /// : vector<[8]xi16> from vector<[8]x[8]xi16> |
398 /// vector.transfer_write %upper_slice, |
399 /// %dest[%slice_idx + %y, %x], %upper_slice_mask |
400 /// : vector<[8]xi16>, memref<?x?xi16> ┘
401 /// %lower_slice_idx = %slice_idx + %c8_vscale ─┐
402 /// %lower_slice_mask = vector.extract %mask[%lower_slice_idx] |
403 /// : vector<[8]xi1> from vector<[16]x[8]xi1> |
404 /// %lower_slice = vector.extract %lower_tile[%slice_idx] |- Store lower
405 /// : vector<[8]xi16> from vector<[8]x[8]xi16> | tile
406 /// vector.transfer_write %lower_slice, |
407 /// %dest[%lower_slice_idx + %y, %x], %lower_slice_mask |
408 /// : vector<[8]xi16>, memref<?x?xi16> ┘
409 /// }
410 /// ```
411 struct LegalizeMultiTileTransferWriteAsStoreLoop
412  : public OneToNOpConversionPattern<vector::TransferWriteOp> {
414 
415  LogicalResult
416  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
417  OneToNPatternRewriter &rewriter) const override {
418  if (writeOp.hasPureTensorSemantics())
419  return rewriter.notifyMatchFailure(
420  writeOp, "TODO: tensor semantics are unsupported");
421 
422  auto permutationMap = writeOp.getPermutationMap();
423  if (!permutationMap.isPermutation())
424  return rewriter.notifyMatchFailure(writeOp,
425  kMatchFailureNonPermutationMap);
426 
427  bool transposed = !permutationMap.isIdentity();
428  if (transposed)
429  return rewriter.notifyMatchFailure(writeOp,
430  "TODO: transpose unsupported");
431 
432  auto vectorType = writeOp.getVectorType();
433  if (!isMultipleOfSMETileVectorType(vectorType))
434  return rewriter.notifyMatchFailure(writeOp,
435  kMatchFailureNotSMETileTypeMultiple);
436 
437  // Note: We also disallow masks where any dimension is > 16 because that
438  // prevents the masking from being lowered to use arm_sve.psel.
439  auto mask = writeOp.getMask();
440  if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
441  vectorType.getDimSize(1) > 16)))
442  return rewriter.notifyMatchFailure(writeOp,
443  kMatchFailureUnsupportedMaskOp);
444 
445  auto loc = writeOp.getLoc();
446  auto createVscaleMultiple =
447  vector::makeVscaleConstantBuilder(rewriter, loc);
448 
449  // Get SME tile and slice types.
450  auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
451  auto minTileSlices = smeTileType.getDimSize(0);
452  VectorType sliceMaskType =
453  VectorType::get(minTileSlices, rewriter.getI1Type(), true);
454 
455  // Create loop over all tile slices.
456  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
457  auto upperBound = createVscaleMultiple(minTileSlices);
458  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
459  auto storeLoop =
460  rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
461  rewriter.setInsertionPointToStart(storeLoop.getBody());
462 
463  // For each sub-tile of the multi-tile `vectorType`.
464  auto inputSMETiles = adaptor.getVector();
465  auto tileSliceIndex = storeLoop.getInductionVar();
466  for (auto [index, smeTile] : llvm::enumerate(
467  decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
468  // The coordinates of the tile within `vectorType`.
469  auto tileRow = createVscaleMultiple(smeTile.row);
470  auto tileCol = createVscaleMultiple(smeTile.col);
471 
472  // The current slice of `vectorType` we are processing.
473  auto sliceIndex =
474  rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);
475 
476  // Where in the destination memref the current slice will be stored.
477  auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
478  writeOp.getIndices()[0]);
479  auto storeCol =
480  rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);
481 
482  // Extract the mask for the current slice.
483  Value sliceMask = nullptr;
484  if (mask) {
485  sliceMask = rewriter.create<vector::ExtractOp>(
486  loc, mask, OpFoldResult(sliceIndex));
487  if (sliceMaskType != sliceMask.getType())
488  sliceMask = rewriter.create<vector::ScalableExtractOp>(
489  loc, sliceMaskType, sliceMask, smeTile.col);
490  }
491 
492  // Extract and store the current slice.
493  Value tile = inputSMETiles[index];
494  auto slice =
495  rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
496  rewriter.create<vector::TransferWriteOp>(
497  loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
498  AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
499  sliceMask,
500  rewriter.getBoolArrayAttr(
501  ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
502  }
503 
504  rewriter.eraseOp(writeOp);
505  return success();
506  }
507 };
508 
509 //===----------------------------------------------------------------------===//
510 // ArmSME-specific fixup canonicalizations/folds
511 //===----------------------------------------------------------------------===//
512 
513 /// Folds an extract from a 3D `vector.create_mask` (which is a vector of
514 /// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
515 /// necessary for the mask to be lowered to ArmSME.
516 ///
517 /// Example:
518 ///
519 /// BEFORE:
520 /// ```mlir
521 /// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
522 /// %subMask = vector.extract %mask[2]
523 /// : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
524 /// ```
525 ///
526 /// AFTER:
527 /// ```mlir
528 /// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
529 /// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
530 /// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
531 /// ```
532 struct FoldExtractFromVectorOfSMELikeCreateMasks
533  : public OpRewritePattern<vector::ExtractOp> {
535 
536  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
537  PatternRewriter &rewriter) const override {
538  auto loc = extractOp.getLoc();
539  auto createMaskOp =
540  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
541  if (!createMaskOp)
542  return rewriter.notifyMatchFailure(
543  extractOp, "extract not from vector.create_mask op");
544 
545  VectorType extractedMaskType =
546  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
547  if (!extractedMaskType)
548  return rewriter.notifyMatchFailure(extractOp,
549  "extracted type is not a vector type");
550 
551  auto numScalable = extractedMaskType.getNumScalableDims();
552  if (numScalable != 2)
553  return rewriter.notifyMatchFailure(
554  extractOp, "expected extracted type to be an SME-like mask");
555 
556  // TODO: Support multiple extraction indices.
557  if (extractOp.getStaticPosition().size() != 1)
558  return rewriter.notifyMatchFailure(
559  extractOp, "only a single extraction index is supported");
560 
561  auto frontMaskDim = createMaskOp.getOperand(0);
562  if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
563  return rewriter.notifyMatchFailure(
564  extractOp,
565  "constant vector.create_masks dims should be folded elsewhere");
566 
567  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
568  auto extractionIndex = getValueOrCreateConstantIndexOp(
569  rewriter, loc, extractOp.getMixedPosition()[0]);
570  auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
571  loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
572  frontMaskDim);
573  auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
574  loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);
575 
576  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
577  extractOp, extractedMaskType,
578  ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
579  return success();
580  }
581 };
582 
583 /// A vector type where no fixed dimension comes after a scalable dimension.
584 bool isLegalVectorType(VectorType vType) {
585  bool seenFixedDim = false;
586  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
587  seenFixedDim |= !scalableFlag;
588  if (seenFixedDim && scalableFlag)
589  return false;
590  }
591  return true;
592 }
593 
594 /// Lifts an illegal vector.transpose and vector.transfer_read to a
595 /// memref.subview + memref.transpose, followed by a legal read.
596 ///
597 /// 'Illegal' here means a leading scalable dimension and a fixed trailing
598 /// dimension, which has no valid lowering.
599 ///
600 /// The memref.transpose is metadata-only transpose that produces a strided
601 /// memref, which eventually becomes a loop reading individual elements.
602 ///
603 /// Example:
604 ///
605 /// BEFORE:
606 /// ```mlir
607 /// %illegalRead = vector.transfer_read %memref[%a, %b]
608 /// : memref<?x?xf32>, vector<[8]x4xf32>
609 /// %legalType = vector.transpose %illegalRead, [1, 0]
610 /// : vector<[8]x4xf32> to vector<4x[8]xf32>
611 /// ```
612 ///
613 /// AFTER:
614 /// ```mlir
615 /// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
616 /// : memref<?x?xf32> to memref<?x?xf32>
617 /// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
618 /// : memref<?x?xf32> to memref<?x?xf32>
619 /// %legalType = vector.transfer_read %transpose[%c0, %c0]
620 /// : memref<?x?xf32>, vector<4x[8]xf32>
621 /// ```
622 struct LiftIllegalVectorTransposeToMemory
623  : public OpRewritePattern<vector::TransposeOp> {
625 
626  static Value getExtensionSource(Operation *op) {
627  if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
628  return op->getOperand(0);
629  return {};
630  }
631 
632  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
633  PatternRewriter &rewriter) const override {
634  auto sourceType = transposeOp.getSourceVectorType();
635  auto resultType = transposeOp.getResultVectorType();
636  if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
637  return rewriter.notifyMatchFailure(transposeOp,
638  kMatchFailureNotIllegalToLegal);
639 
640  // Look through extend for transfer_read.
641  Value maybeRead = transposeOp.getVector();
642  auto *transposeSourceOp = maybeRead.getDefiningOp();
643  Operation *extendOp = nullptr;
644  if (Value extendSource = getExtensionSource(transposeSourceOp)) {
645  maybeRead = extendSource;
646  extendOp = transposeSourceOp;
647  }
648 
649  auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
650  if (!illegalRead)
651  return rewriter.notifyMatchFailure(
652  transposeOp,
653  "expected source to be (possibly extended) transfer_read");
654 
655  if (!illegalRead.getPermutationMap().isIdentity())
656  return rewriter.notifyMatchFailure(
657  illegalRead, "expected read to have identity permutation map");
658 
659  auto loc = transposeOp.getLoc();
660  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
661  auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
662 
663  // Create a subview that matches the size of the illegal read vector type.
664  auto readType = illegalRead.getVectorType();
665  auto readSizes = llvm::map_to_vector(
666  llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
667  [&](auto dim) -> Value {
668  auto [size, isScalable] = dim;
669  auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
670  if (!isScalable)
671  return dimSize;
672  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
673  return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
674  });
675  SmallVector<Value> strides(readType.getRank(), Value(one));
676  auto readSubview = rewriter.create<memref::SubViewOp>(
677  loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
678  strides);
679 
680  // Apply the transpose to all values/attributes of the transfer_read:
681  // - The mask
682  Value mask = illegalRead.getMask();
683  if (mask) {
684  // Note: The transpose for the mask should fold into the
685  // vector.create_mask/constant_mask op, which will then become legal.
686  mask = rewriter.create<vector::TransposeOp>(loc, mask,
687  transposeOp.getPermutation());
688  }
689  // - The source memref
691  transposeOp.getPermutation(), getContext());
692  auto transposedSubview = rewriter.create<memref::TransposeOp>(
693  loc, readSubview, AffineMapAttr::get(transposeMap));
694  ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
695  // - The `in_bounds` attribute
696  if (inBoundsAttr) {
697  SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
698  inBoundsAttr.end());
699  applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
700  inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
701  }
702 
703  VectorType legalReadType = resultType.clone(readType.getElementType());
704  // Note: The indices are all zero as the subview is already offset.
705  SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
706  auto legalRead = rewriter.create<vector::TransferReadOp>(
707  loc, legalReadType, transposedSubview, readIndices,
708  illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
709  inBoundsAttr);
710 
711  // Replace the transpose with the new read, extending the result if
712  // necessary.
713  rewriter.replaceOp(transposeOp, [&]() -> Operation * {
714  if (extendOp)
715  return rewriter.create(loc, extendOp->getName().getIdentifier(),
716  Value(legalRead), resultType);
717  return legalRead;
718  }());
719 
720  return success();
721  }
722 };
723 
724 /// A rewrite to turn unit dim transpose-like vector.shape_casts into
725 /// vector.transposes. The shape_cast has to be from an illegal vector type to a
726 /// legal one (as defined by isLegalVectorType).
727 ///
728 /// The reasoning for this is if we've got to this pass and we still have
729 /// shape_casts of illegal types, then they likely will not cancel out. Turning
730 /// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
731 /// eliminate them.
732 ///
733 /// Example:
734 ///
735 /// BEFORE:
736 /// ```mlir
737 /// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
738 /// ```
739 ///
740 /// AFTER:
741 /// ```mlir
742 /// %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
743 /// ```
744 struct ConvertIllegalShapeCastOpsToTransposes
745  : public OpRewritePattern<vector::ShapeCastOp> {
747 
748  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
749  PatternRewriter &rewriter) const override {
750  auto sourceType = shapeCastOp.getSourceVectorType();
751  auto resultType = shapeCastOp.getResultVectorType();
752  if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
753  return rewriter.notifyMatchFailure(shapeCastOp,
754  kMatchFailureNotIllegalToLegal);
755 
756  // Note: If we know that `sourceType` is an illegal vector type (and 2D)
757  // then dim 0 is scalable and dim 1 is fixed.
758  if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
759  return rewriter.notifyMatchFailure(
760  shapeCastOp, "expected source to be a 2D scalable vector with a "
761  "trailing unit dim");
762 
763  auto loc = shapeCastOp.getLoc();
764  auto transpose = rewriter.create<vector::TransposeOp>(
765  loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
766 
767  if (resultType.getRank() == 1)
768  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
769  transpose);
770  else
771  rewriter.replaceOp(shapeCastOp, transpose);
772 
773  return success();
774  }
775 };
776 
777 /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
778 /// the ZA state. This is a workaround rewrite to support these transposes when ZA is
779 /// available.
780 ///
781 /// Example:
782 ///
783 /// BEFORE:
784 /// ```mlir
785 /// %transpose = vector.transpose %vec, [1, 0]
786 /// : vector<2x[4]xf32> to vector<[4]x2xf32>
787 /// vector.transfer_write %transpose, %dest[%y, %x]
788 /// : vector<[4]x2xf32>, memref<?x?xf32>
789 /// ```
790 ///
791 /// AFTER:
792 /// ```mlir
793 /// %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
794 /// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
795 /// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
796 /// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
797 /// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
798 /// %c4_vscale = arith.muli %vscale, %c4 : index
799 /// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
800 /// vector.transfer_write %4, %dest[%y, %x], %mask
801 /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
802 /// : vector<[4]x[4]xf32>, memref<?x?xf32>
803 /// ```
804 ///
805 /// Values larger than a single tile are supported via decomposition.
806 struct LowerIllegalTransposeStoreViaZA
807  : public OpRewritePattern<vector::TransferWriteOp> {
809 
810  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
811  PatternRewriter &rewriter) const override {
812  if (!isSupportedMaskOp(writeOp.getMask()))
813  return rewriter.notifyMatchFailure(writeOp,
814  kMatchFailureUnsupportedMaskOp);
815 
816  auto permutationMap = writeOp.getPermutationMap();
817  if (!permutationMap.isIdentity())
818  return rewriter.notifyMatchFailure(writeOp,
819  kMatchFailureNonPermutationMap);
820 
821  auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
822  if (!transposeOp)
823  return failure();
824 
825  auto sourceType = transposeOp.getSourceVectorType();
826  auto resultType = transposeOp.getResultVectorType();
827 
828  if (resultType.getRank() != 2)
829  return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");
830 
831  if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
832  return rewriter.notifyMatchFailure(
833  transposeOp, "not illegal/unsupported SVE transpose");
834 
835  auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
836  VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);
837 
838  if (sourceType.getDimSize(0) <= 1 ||
839  sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
840  return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");
841 
842  auto loc = writeOp.getLoc();
843  auto createVscaleMultiple =
844  vector::makeVscaleConstantBuilder(rewriter, loc);
845 
846  auto transposeMap = AffineMapAttr::get(
848 
849  // Note: We need to use `get_tile` as there's no vector-level `undef`.
850  Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
851  Value destTensorOrMemref = writeOp.getSource();
852  auto numSlicesPerTile =
853  std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
854  auto numSlices =
855  rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
856  for (auto [index, smeTile] : llvm::enumerate(
857  decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
858  // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
859  // of slices from the source type into the SME tile. Without checking
860  // vscale (and emitting multiple implementations) we can't make use of the
861  // rows of the tile after 1*vscale rows.
862  Value tile = undefTile;
863  for (int d = 0; d < numSlicesPerTile; ++d) {
864  Value vector = rewriter.create<vector::ExtractOp>(
865  loc, transposeOp.getVector(),
866  rewriter.getIndexAttr(d + smeTile.row));
867  if (vector.getType() != smeSliceType) {
868  vector = rewriter.create<vector::ScalableExtractOp>(
869  loc, smeSliceType, vector, smeTile.col);
870  }
871  tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
872  }
873 
874  // 2. Transpose the tile position.
875  auto transposedRow = createVscaleMultiple(smeTile.col);
876  auto transposedCol =
877  rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);
878 
879  // 3. Compute mask for tile store.
880  Value maskRows;
881  Value maskCols;
882  if (auto mask = writeOp.getMask()) {
883  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
884  maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
885  transposedRow);
886  maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
887  transposedCol);
888  maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
889  } else {
890  maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
891  maskCols = numSlices;
892  }
893  auto subMask = rewriter.create<vector::CreateMaskOp>(
894  loc, smeTileType.clone(rewriter.getI1Type()),
895  ValueRange{maskRows, maskCols});
896 
897  // 4. Emit a transposed tile write.
898  auto writeIndices = writeOp.getIndices();
899  Value destRow =
900  rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
901  Value destCol =
902  rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
903  auto smeWrite = rewriter.create<vector::TransferWriteOp>(
904  loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
905  transposeMap, subMask, writeOp.getInBounds());
906 
907  if (writeOp.hasPureTensorSemantics())
908  destTensorOrMemref = smeWrite.getResult();
909  }
910 
911  if (writeOp.hasPureTensorSemantics())
912  rewriter.replaceOp(writeOp, destTensorOrMemref);
913  else
914  rewriter.eraseOp(writeOp);
915 
916  return success();
917  }
918 };
919 
/// Pass driver for the vector legalization. It sets up a 1:N type converter
/// that maps vector types which are a multiple of an SME tile onto a list of
/// SME tile types, registers the legalization patterns, and applies them with
/// `applyPartialOneToNConversion`.
920 struct VectorLegalizationPass
921  : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
922  void runOnOperation() override {
923  auto *context = &getContext();
924  TypeConverter converter;
925  RewritePatternSet patterns(context);
     // Identity conversion: a type maps to itself.
926  converter.addConversion([](Type type) { return type; });
     // Vector types that are a multiple of an SME tile map to N copies of the
     // SME tile type for their element type. For any other vector type this
     // callback returns std::nullopt (presumably deferring to the identity
     // conversion registered above -- confirm against the TypeConverter
     // documentation).
927  converter.addConversion(
928  [](VectorType vectorType,
929  SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
930  if (!isMultipleOfSMETileVectorType(vectorType))
931  return std::nullopt;
932  auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
933  auto smeTileType =
934  getSMETileTypeForElement(vectorType.getElementType());
935  types = SmallVector<Type>(smeTileCount, smeTileType);
936  return success();
937  });
938 
     // These patterns are constructed from the context alone, without the
     // type converter.
939  patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
940  LiftIllegalVectorTransposeToMemory,
941  ConvertIllegalShapeCastOpsToTransposes,
942  LowerIllegalTransposeStoreViaZA>(context);
943  // Note: These two patterns are added with a high benefit to ensure:
944  // - Masked outer products are handled before unmasked ones
945  // - Multi-tile writes are lowered as a store loop (if possible)
946  patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
947  LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
948  /*benefit=*/1024);
     // The remaining decomposition patterns use the type converter and are
     // added without an explicit benefit.
949  patterns.add<LegalizeArithConstantOpsByDecomposition,
950  LegalizeVectorOuterProductOpsByDecomposition,
951  LegalizeTransferReadOpsByDecomposition,
952  LegalizeTransferWriteOpsByDecomposition>(converter, context);
     // NOTE(review): the embedded numbering jumps from 952 to 955, so two
     // lines of the original file are missing from this listing at this point
     // -- confirm against the original file before relying on this block.
955 
     // Apply the patterns; any failure is reported as a pass failure.
956  if (failed(applyPartialOneToNConversion(getOperation(), converter,
957  std::move(patterns))))
958  return signalPassFailure();
959  }
960 };
961 
962 } // namespace
963 
965  return std::make_unique<VectorLegalizationPass>();
966 }
static MLIRContext * getContext(OpFoldResult val)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value createMask(AffineForOp vecForOp, VectorizationState &state)
Creates a mask used to filter out garbage elements in the last iteration of unaligned loops.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
IntegerType getI1Type()
Definition: Builders.cpp:97
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:310
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class is a wrapper around OneToNConversionPattern for matching against instances of a particular...
OneToNOpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Specialization of PatternRewriter that OneToNConversionPatterns use.
void replaceOp(Operation *op, ValueRange newValues, const OneToNTypeMapping &resultMapping)
Replaces the results of the operation with the specified list of values mapped back to the original t...
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:317
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:342
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
Definition: Utils.cpp:114
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:18
std::unique_ptr< Pass > createVectorLegalizationPass()
Pass that legalizes vectors so they can be lowered to ArmSME.
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
Definition: Utils.cpp:97
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateSCFStructuralOneToNTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the provided pattern set with patterns that do 1:N type conversions on (some) SCF ops.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc)
Returns a functor (int64_t -> Value) which returns a constant vscale multiple.
Definition: VectorUtils.h:113
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
Include the generated interface declarations.
void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:1282
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362