//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");
/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example for an [8]x[8] tile with four [4]x[4]
/// sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {
  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
  int row{0};
  int col{0};
  // The SME tile type.
  VectorType type;
};

/// Adds a constant elementwise scalable offset to `indices` (which are of
/// equal length). For example, in the 2D case this would return:
///   { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {
  auto vscale = builder.create<vector::VectorScaleOp>(loc);
  return llvm::map_to_vector(
      llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
        auto [index, base] = pair;
        auto offset = builder.create<arith::MulIOp>(
            loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
        return builder.create<arith::AddIOp>(loc, index, offset);
      });
}

/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
/// indices for one of the SME sub-tiles it will decompose into.
///
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
/// indices for each tile would need to be adjusted as follows:
///
/// initial indices = [a,b], initial size = 8x8, target size = 4x4
/// ┌─────────────┬─────────────┐
/// │[a,b]        │[a,b+4]      │
/// │             │             │
/// ├─────────────┼─────────────┤
/// │[a+4,b]      │[a+4,b+4]    │
/// │             │             │
/// └─────────────┴─────────────┘
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {
  return addConstantScalableOffset(builder, loc, indices,
                                   {smeTile.row, smeTile.col});
}

/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is just no mask, or vector.create_mask.
/// TODO: Add support for vector.constant_mask once required for SME.
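/// For example (illustrative; %a and %b are hypothetical index values), the
/// following mask is supported:
/// ```mlir
/// %mask = vector.create_mask %a, %b : vector<[8]x[8]xi1>
/// ```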
bool isSupportedMaskOp(Value mask) {
  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
}

/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {
  assert(isSupportedMaskOp(mask));
  if (!mask)
    return Value{};
  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
  // The operands of `vector.create_mask` (from a 2D perspective) are the
  // coordinates where the mask ends. So we subtract where this tile starts
  // from the mask operands to get the parameters for this sub-tile.
  auto smeTileMaskDims = addConstantScalableOffset(
      builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
  auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
      loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
  return smeTileCreateMask.getResult();
}

/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, if decomposing an [8]x[8] into
/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
/// (4, 4).
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
  return llvm::map_range(
      StaticTileOffsetRange(
          type.getShape(),
          {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
           std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
      [=](auto indices) {
        int row = int(indices[0]);
        int col = int(indices[1]);
        if (transposeIndices)
          std::swap(row, col);
        return SMESubTile{row, col, smeTileType};
      });
}

/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
/// `type`.
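/// For example (illustrative; assumes f32, where an SME tile is
/// vector<[4]x[4]xf32>), a vector<[8]x[8]xf32> holds
/// (8 * 8) / (4 * 4) = 4 SME tiles.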
int getNumberOfSMETilesForVectorType(VectorType type) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  int64_t vectorRows = type.getDimSize(0);
  int64_t vectorCols = type.getDimSize(1);
  auto elementType = type.getElementType();
  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
}

/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
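///
/// A sketch of the decomposition (illustrative; assumes f32, where an SME
/// tile is vector<[4]x[4]xf32>; names are hypothetical):
///
/// BEFORE:
/// ```mlir
/// %splat = arith.constant dense<1.0> : vector<[8]x[8]xf32>
/// ```
///
/// AFTER (a single tile-sized splat, reused for all four sub-tiles):
/// ```mlir
/// %tileSplat = arith.constant dense<1.0> : vector<[4]x[4]xf32>
/// ```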
struct LegalizeArithConstantOpsByDecomposition
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    if (!vectorType || !denseAttr || !denseAttr.isSplat())
      return failure();

    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(constantOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
    auto tileSplat = rewriter.create<arith::ConstantOp>(
        constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
    SmallVector<Value> repl(tileCount, tileSplat);
    rewriter.replaceOpWithMultiple(constantOp, {repl});

    return success();
  }
};

/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
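///
/// A sketch of the decomposition (illustrative; assumes f32, where an SME
/// tile is vector<[4]x[4]xf32>; names are hypothetical):
///
/// BEFORE:
/// ```mlir
/// %op = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>
/// ```
///
/// AFTER (one outer product per sub-tile; shown for sub-tile (0, 0)):
/// ```mlir
/// %lhs0 = vector.scalable.extract %lhs[0]
///   : vector<[4]xf32> from vector<[8]xf32>
/// %rhs0 = vector.scalable.extract %rhs[0]
///   : vector<[4]xf32> from vector<[8]xf32>
/// %tile00 = vector.outerproduct %lhs0, %rhs0
///   : vector<[4]xf32>, vector<[4]xf32>
/// ```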
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::OuterProductOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp,
                  OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = outerProductOp.getResultVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    Value mask;
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      auto maskOp = outerProductOp.getMaskingOp();
      mask = maskOp.getMask();
      rootOp = maskOp;
      rewriter.setInsertionPoint(rootOp);
    }

    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureUnsupportedMaskOp);

    ValueRange accSMETiles = adaptor.getAcc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);

    SmallVector<Value> resultSMETiles;
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {

      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto lhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getLhs(), smeTile.row);
      auto rhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getRhs(), smeTile.col);
      auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
          loc, smeTileType, lhs, rhs,
          !accSMETiles.empty() ? accSMETiles[index] : Value{},
          outerProductOp.getKind());

      auto maskedOuterProduct =
          vector::maskOperation(rewriter, smeOuterProduct, smeMask);
      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
    }

    rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
    return success();
  }
};

// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
// get the help of the type conversion), but doing so results in the type
// conversion adding target materializations in the `vector.mask` region
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::MaskOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
            maskOp.getMaskableOp())) {
      LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
                                                           getContext());
      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
          outerProductOp, rewriter);
    }
    return failure();
  }
};

/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
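///
/// A sketch of the decomposition (illustrative pseudo-MLIR; assumes f32,
/// where an SME tile is vector<[4]x[4]xf32>; names are hypothetical):
///
/// BEFORE:
/// ```mlir
/// %read = vector.transfer_read %src[%a, %b]
///   : memref<?x?xf32>, vector<[8]x[8]xf32>
/// ```
///
/// AFTER (four tile-sized reads at vscale-scaled offsets; shown for
/// sub-tile (4, 0)):
/// ```mlir
/// %row = arith.addi %a, %c4_vscale : index
/// %tile10 = vector.transfer_read %src[%row, %b]
///   : memref<?x?xf32>, vector<[4]x[4]xf32>
/// ```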
struct LegalizeTransferReadOpsByDecomposition
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = readOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = readOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = readOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = readOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());

    SmallVector<Value> resultSMETiles;
    for (SMESubTile smeTile :
         decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeRead = rewriter.create<vector::TransferReadOp>(
          loc, smeTileType, readOp.getSource(),
          getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
          readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
          readOp.getInBoundsAttr());
      resultSMETiles.push_back(smeRead);
    }

    rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
    return success();
  }
};

/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
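///
/// A sketch of the decomposition (illustrative pseudo-MLIR; assumes f32,
/// where an SME tile is vector<[4]x[4]xf32>; names are hypothetical): each
/// sub-tile is written at a vscale-scaled offset, e.g. for sub-tile (0, 4):
/// ```mlir
/// %col = arith.addi %b, %c4_vscale : index
/// vector.transfer_write %tile01, %dest[%a, %col]
///   : vector<[4]x[4]xf32>, memref<?x?xf32>
/// ```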
struct LegalizeTransferWriteOpsByDecomposition
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = writeOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto inputSMETiles = adaptor.getVector();

    Value destTensorOrMemref = writeOp.getSource();
    for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
             rewriter, vectorType, smeTileType, transposed))) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, inputSMETiles[index], destTensorOrMemref,
          getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
          writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

/// Legalize a multi-tile transfer_write as a single store loop. This is done
/// as part of type decomposition because at this level we know each tile write
/// is disjoint, but that information is lost after decomposition (without
/// analysis to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
///   : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
///   %upper_slice_mask = vector.extract %mask[%slice_idx]           ─┐
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                     |
///   %upper_slice = vector.extract %upper_tile[%slice_idx]           |- Store
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                    |  upper
///   vector.transfer_write %upper_slice,                             |  tile
///     %dest[%slice_idx + %y, %x], %upper_slice_mask                 |
///     : vector<[8]xi16>, memref<?x?xi16>                           ─┘
///   %lower_slice_idx = %slice_idx + %c8_vscale                     ─┐
///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]      |
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                     |
///   %lower_slice = vector.extract %lower_tile[%slice_idx]           |- Store
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                    |  lower
///   vector.transfer_write %lower_slice,                             |  tile
///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask           |
///     : vector<[8]xi16>, memref<?x?xi16>                           ─┘
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (writeOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          writeOp, "TODO: tensor semantics are unsupported");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    bool transposed = !permutationMap.isIdentity();
    if (transposed)
      return rewriter.notifyMatchFailure(writeOp,
                                         "TODO: transpose unsupported");

    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // Note: We also disallow masks where any dimension is > 16 because that
    // prevents the masking from being lowered to use arm_sve.psel.
    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
                                              vectorType.getDimSize(1) > 16)))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    // Get SME tile and slice types.
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto minTileSlices = smeTileType.getDimSize(0);
    VectorType sliceMaskType =
        VectorType::get(minTileSlices, rewriter.getI1Type(), true);

    // Create loop over all tile slices.
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = createVscaleMultiple(minTileSlices);
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto storeLoop =
        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    rewriter.setInsertionPointToStart(storeLoop.getBody());

    // For each sub-tile of the multi-tile `vectorType`.
    auto inputSMETiles = adaptor.getVector();
    auto tileSliceIndex = storeLoop.getInductionVar();
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
      // The coordinates of the tile within `vectorType`.
      auto tileRow = createVscaleMultiple(smeTile.row);
      auto tileCol = createVscaleMultiple(smeTile.col);

      // The current slice of `vectorType` we are processing.
      auto sliceIndex =
          rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);

      // Where in the destination memref the current slice will be stored.
      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
                                                     writeOp.getIndices()[0]);
      auto storeCol =
          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);

      // Extract the mask for the current slice.
      Value sliceMask = nullptr;
      if (mask) {
        sliceMask = rewriter.create<vector::ExtractOp>(
            loc, mask, OpFoldResult(sliceIndex));
        if (sliceMaskType != sliceMask.getType())
          sliceMask = rewriter.create<vector::ScalableExtractOp>(
              loc, sliceMaskType, sliceMask, smeTile.col);
      }

      // Extract and store the current slice.
      Value tile = inputSMETiles[index];
      auto slice =
          rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
      rewriter.create<vector::TransferWriteOp>(
          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
          sliceMask,
          rewriter.getBoolArrayAttr(
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
    }

    rewriter.eraseOp(writeOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
/// %subMask = vector.extract %mask[2]
///   : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
/// ```
///
/// AFTER:
/// ```mlir
/// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
/// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
/// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
/// ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto loc = extractOp.getLoc();
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(
          extractOp, "extract not from vector.create_mask op");

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return rewriter.notifyMatchFailure(extractOp,
                                         "extracted type is not a vector type");

    auto numScalable = extractedMaskType.getNumScalableDims();
    if (numScalable != 2)
      return rewriter.notifyMatchFailure(
          extractOp, "expected extracted type to be an SME-like mask");

    // TODO: Support multiple extraction indices.
    if (extractOp.getStaticPosition().size() != 1)
      return rewriter.notifyMatchFailure(
          extractOp, "only a single extraction index is supported");

    auto frontMaskDim = createMaskOp.getOperand(0);
    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
      return rewriter.notifyMatchFailure(
          extractOp,
          "constant vector.create_masks dims should be folded elsewhere");

    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto extractionIndex = getValueOrCreateConstantIndexOp(
        rewriter, loc, extractOp.getMixedPosition()[0]);
    auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
        loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
        frontMaskDim);
    auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
        loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);

    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
        extractOp, extractedMaskType,
        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
    return success();
  }
};

/// A vector type where no fixed dimension comes after a scalable dimension.
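/// For example, vector<2x[4]xf32> and vector<[4]x[4]xf32> are legal, but
/// vector<[4]x2xf32> is not (a fixed dim follows a scalable dim).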
bool isLegalVectorType(VectorType vType) {
  bool seenFixedDim = false;
  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
    seenFixedDim |= !scalableFlag;
    if (seenFixedDim && scalableFlag)
      return false;
  }
  return true;
}

/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
/// 'Illegal' here means a leading scalable dimension and a fixed trailing
/// dimension, which has no valid lowering.
///
/// The memref.transpose is a metadata-only transpose that produces a strided
/// memref, which eventually becomes a loop reading individual elements.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %illegalRead = vector.transfer_read %memref[%a, %b]
///   : memref<?x?xf32>, vector<[8]x4xf32>
/// %legalType = vector.transpose %illegalRead, [1, 0]
///   : vector<[8]x4xf32> to vector<4x[8]xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
///   : memref<?x?xf32> to memref<?x?xf32>
/// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
///   : memref<?x?xf32> to memref<?x?xf32>
/// %legalType = vector.transfer_read %transpose[%c0, %c0]
///   : memref<?x?xf32>, vector<4x[8]xf32>
/// ```
struct LiftIllegalVectorTransposeToMemory
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  static Value getExtensionSource(Operation *op) {
    if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
      return op->getOperand(0);
    return {};
  }

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(transposeOp,
                                         kMatchFailureNotIllegalToLegal);

    // Look through extend for transfer_read.
    Value maybeRead = transposeOp.getVector();
    auto *transposeSourceOp = maybeRead.getDefiningOp();
    Operation *extendOp = nullptr;
    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
      maybeRead = extendSource;
      extendOp = transposeSourceOp;
    }

    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
    if (!illegalRead)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "expected source to be (possibly extended) transfer_read");

    if (!illegalRead.getPermutationMap().isIdentity())
      return rewriter.notifyMatchFailure(
          illegalRead, "expected read to have identity permutation map");

    auto loc = transposeOp.getLoc();
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    // Create a subview that matches the size of the illegal read vector type.
    auto readType = illegalRead.getVectorType();
    auto readSizes = llvm::map_to_vector(
        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
        [&](auto dim) -> Value {
          auto [size, isScalable] = dim;
          auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
          if (!isScalable)
            return dimSize;
          auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
          return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
        });
    SmallVector<Value> strides(readType.getRank(), Value(one));
    auto readSubview = rewriter.create<memref::SubViewOp>(
        loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
        strides);

    // Apply the transpose to all values/attributes of the transfer_read:
    // - The mask
    Value mask = illegalRead.getMask();
    if (mask) {
      // Note: The transpose for the mask should fold into the
      // vector.create_mask/constant_mask op, which will then become legal.
      mask = rewriter.create<vector::TransposeOp>(loc, mask,
                                                  transposeOp.getPermutation());
    }
    // - The source memref
    auto transposeMap = AffineMap::getPermutationMap(
        transposeOp.getPermutation(), getContext());
    auto transposedSubview = rewriter.create<memref::TransposeOp>(
        loc, readSubview, AffineMapAttr::get(transposeMap));
    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
    // - The `in_bounds` attribute
    if (inBoundsAttr) {
      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
                                            inBoundsAttr.end());
      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
    }

    VectorType legalReadType = resultType.clone(readType.getElementType());
    // Note: The indices are all zero as the subview is already offset.
    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
    auto legalRead = rewriter.create<vector::TransferReadOp>(
        loc, legalReadType, transposedSubview, readIndices,
        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
        inBoundsAttr);

    // Replace the transpose with the new read, extending the result if
    // necessary.
    rewriter.replaceOp(transposeOp, [&]() -> Operation * {
      if (extendOp)
        return rewriter.create(loc, extendOp->getName().getIdentifier(),
                               Value(legalRead), resultType);
      return legalRead;
    }());

    return success();
  }
};

/// A rewrite to turn unit dim transpose-like vector.shape_casts into
/// vector.transposes. The shape_cast has to be from an illegal vector type to
/// a legal one (as defined by isLegalVectorType).
///
/// The reasoning for this is that if we've reached this pass and still have
/// shape_casts of illegal types, then they likely will not cancel out. Turning
/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
/// eliminate them.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
/// ```
struct ConvertIllegalShapeCastOpsToTransposes
    : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = shapeCastOp.getSourceVectorType();
    auto resultType = shapeCastOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(shapeCastOp,
                                         kMatchFailureNotIllegalToLegal);

    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
    // then dim 0 is scalable and dim 1 is fixed.
    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
      return rewriter.notifyMatchFailure(
          shapeCastOp, "expected source to be a 2D scalable vector with a "
                       "trailing unit dim");

    auto loc = shapeCastOp.getLoc();
    auto transpose = rewriter.create<vector::TransposeOp>(
        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});

    if (resultType.getRank() == 1)
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
                                                       transpose);
    else
      rewriter.replaceOp(shapeCastOp, transpose);

    return success();
  }
};

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
/// use the ZA state. This is a workaround to support these transposes when ZA
/// is available.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %transpose = vector.transpose %vec, [1, 0]
///   : vector<2x[4]xf32> to vector<[4]x2xf32>
/// vector.transfer_write %transpose, %dest[%y, %x]
///   : vector<[4]x2xf32>, memref<?x?xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
/// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
/// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
/// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
/// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
/// vector.transfer_write %4, %dest[%y, %x], %mask
///   {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
///   : vector<[4]x[4]xf32>, memref<?x?xf32>
/// ```
///
/// Values larger than a single tile are supported via decomposition.
struct LowerIllegalTransposeStoreViaZA
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!isSupportedMaskOp(writeOp.getMask()))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!transposeOp)
      return failure();

    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();

    if (resultType.getRank() != 2)
      return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");

    if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(
          transposeOp, "not illegal/unsupported SVE transpose");

    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
    VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);

    if (sourceType.getDimSize(0) <= 1 ||
        sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
      return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto transposeMap = AffineMapAttr::get(
        AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));

    // Note: We need to use `get_tile` as there's no vector-level `undef`.
    Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
    Value destTensorOrMemref = writeOp.getSource();
    auto numSlicesPerTile =
        std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
    auto numSlices =
        rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
      // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
      // of slices from the source type into the SME tile. Without checking
      // vscale (and emitting multiple implementations) we can't make use of
      // the rows of the tile after 1*vscale rows.
      Value tile = undefTile;
      for (int d = 0; d < numSlicesPerTile; ++d) {
        Value vector = rewriter.create<vector::ExtractOp>(
            loc, transposeOp.getVector(),
            rewriter.getIndexAttr(d + smeTile.row));
        if (vector.getType() != smeSliceType) {
          vector = rewriter.create<vector::ScalableExtractOp>(
              loc, smeSliceType, vector, smeTile.col);
        }
        tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
      }

      // 2. Transpose the tile position.
      auto transposedRow = createVscaleMultiple(smeTile.col);
      auto transposedCol =
          rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);

      // 3. Compute mask for tile store.
      Value maskRows;
      Value maskCols;
      if (auto mask = writeOp.getMask()) {
        auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
        maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
                                                  transposedRow);
        maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
                                                  transposedCol);
        maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
      } else {
        maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
        maskCols = numSlices;
      }
      auto subMask = rewriter.create<vector::CreateMaskOp>(
          loc, smeTileType.clone(rewriter.getI1Type()),
          ValueRange{maskRows, maskCols});

      // 4. Emit a transposed tile write.
      auto writeIndices = writeOp.getIndices();
      Value destRow =
          rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
      Value destCol =
          rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
          transposeMap, subMask, writeOp.getInBounds());

      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  void runOnOperation() override {
    auto *context = &getContext();
    TypeConverter converter;
    RewritePatternSet patterns(context);
    converter.addConversion([](Type type) { return type; });
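    // Illustrative example: vector<[8]x[8]xf32> is a multiple of the f32 SME
    // tile type vector<[4]x[4]xf32>, so the conversion below maps it to four
    // vector<[4]x[4]xf32> values.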
    converter.addConversion(
        [](VectorType vectorType,
           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
          if (!isMultipleOfSMETileVectorType(vectorType))
            return std::nullopt;
          auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
          auto smeTileType =
              getSMETileTypeForElement(vectorType.getElementType());
          types = SmallVector<Type>(smeTileCount, smeTileType);
          return success();
        });

    // Apply preprocessing patterns.
    RewritePatternSet rewritePatterns(context);
    rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                        LiftIllegalVectorTransposeToMemory,
                        ConvertIllegalShapeCastOpsToTransposes,
                        LowerIllegalTransposeStoreViaZA>(context);
    if (failed(
            applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
      return signalPassFailure();

    // Note: These two patterns are added with a high benefit to ensure:
    //  - Masked outer products are handled before unmasked ones
    //  - Multi-tile writes are lowered as a store loop (if possible)
    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
                                                            /*benefit=*/1024);
    patterns.add<LegalizeArithConstantOpsByDecomposition,
                 LegalizeVectorOuterProductOpsByDecomposition,
                 LegalizeTransferReadOpsByDecomposition,
                 LegalizeTransferWriteOpsByDecomposition>(converter, context);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversions(converter, patterns);

    ConversionTarget target(getContext());
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return converter.isLegal(op); });
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
  return std::make_unique<VectorLegalizationPass>();
}