MLIR  19.0.0git
VectorLegalization.cpp
Go to the documentation of this file.
1 //===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass legalizes vector operations so they can be lowered to ArmSME.
10 //
11 // Note: In the context of this pass 'tile' always refers to an SME tile.
12 //
13 //===----------------------------------------------------------------------===//
14 
25 
26 #define DEBUG_TYPE "arm-sme-vector-legalization"
27 
28 namespace mlir::arm_sme {
29 #define GEN_PASS_DEF_VECTORLEGALIZATION
30 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
31 } // namespace mlir::arm_sme
32 
33 using namespace mlir;
34 using namespace mlir::arm_sme;
35 
36 namespace {
37 
38 //===----------------------------------------------------------------------===//
39 // Decomposition of vector operations larger than an SME tile
40 //===----------------------------------------------------------------------===//
41 
42 // Common match failure reasons.
43 static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
44  "op vector size is not multiple of SME tiles");
45 static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
46  "op mask is unsupported for legalization/decomposition");
47 static constexpr StringLiteral
48  kMatchFailureNonPermutationMap("op affine map is not a permutation");
49 static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
50  "expected transpose from illegal type to legal type");
51 
52 /// An SMESubTile represents a single SME-sized sub-tile from decomposing a
53 /// larger vector type. The (`row`, `col`) are the position of the tile in the
54 /// original vector type. For example for an [8]x[8] tile with four [4]x[4]
55 /// sub-tiles, we would have:
56 ///
57 /// 8 x vscale
58 /// ┌─────────────┬─────────────┐
59 /// │(0,0) │(0,4) │
60 /// │ │ │
61 /// ├─────────────┼─────────────┤ 8 x vscale
62 /// │(4,0) │(4,4) │
63 /// │ │ │
64 /// └─────────────┴─────────────┘
65 struct SMESubTile {
66  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
67  int row{0};
68  int col{0};
69  // The SME tile type.
70  VectorType type;
71 };
72 
73 /// Adds a constant elementwise scalable offset to `indices` (which are of equal
74 /// length). For example, in the 2D case this would return:
75 // { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
76 SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
77  Location loc,
78  ValueRange indices,
79  ArrayRef<int> scalableOffsets) {
80  auto vscale = builder.create<vector::VectorScaleOp>(loc);
81  return llvm::map_to_vector(
82  llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
83  auto [index, base] = pair;
84  auto offset = builder.create<arith::MulIOp>(
85  loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
86  return builder.create<arith::AddIOp>(loc, index, offset);
87  });
88 }
89 
90 /// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
91 /// indices for one of the SME sub-tiles it will decompose into.
92 ///
93 /// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
94 /// indices for each tile would need to be adjusted as follows:
95 ///
96 /// initial indices = [a,b], inital size = 8x8, target size = 4x4
97 /// ┌─────────────┬─────────────┐
98 /// │[a,b] │[a,b+4] │
99 /// │ │ │
100 /// ├─────────────┼─────────────┤
101 /// │[a+4,b] │[a+4,b+4] │
102 /// │ │ │
103 /// └─────────────┴─────────────┘
104 SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
105  ValueRange indices,
106  SMESubTile smeTile) {
107  return addConstantScalableOffset(builder, loc, indices,
108  {smeTile.row, smeTile.col});
109 }
110 
111 /// Returns true if `mask` is generated by an operation that can be decomposed
112 /// for SME. Currently, that is just no mask, or vector.create_mask.
113 /// TODO: Add support for vector.constant_mask once required for SME.
114 bool isSupportedMaskOp(Value mask) {
115  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
116 }
117 
118 /// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
119 Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
120  SMESubTile smeTile) {
121  assert(isSupportedMaskOp(mask));
122  if (!mask)
123  return Value{};
124  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
125  // The operands of `vector.create_mask` (from a 2D perspective) are the
126  // coordinates where the mask ends. So we subtract where this tile starts,
127  // from the mask operands to get the parameters for this sub-tile.
128  auto smeTileMaskDims = addConstantScalableOffset(
129  builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
130  auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
131  loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
132  return smeTileCreateMask.getResult();
133 }
134 
135 /// Constructs an iterator that returns each SME tile (with coordinates)
136 /// contained within a VectorType. For example, if decomposing an [8]x[8] into
137 /// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
138 /// (4, 4).
139 auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
140  VectorType smeTileType,
141  bool transposeIndices = false) {
142  assert(isMultipleOfSMETileVectorType(type) &&
143  "`type` not multiple of SME tiles");
144  return llvm::map_range(
145  StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
146  smeTileType.getDimSize(1)}),
147  [=](auto indices) {
148  int row = int(indices[0]);
149  int col = int(indices[1]);
150  if (transposeIndices)
151  std::swap(row, col);
152  return SMESubTile{row, col, smeTileType};
153  });
154 }
155 
156 /// Returns the number of SME tiles that fit into the (2D-scalable) vector type
157 /// `type`.
158 int getNumberOfSMETilesForVectorType(VectorType type) {
159  assert(isMultipleOfSMETileVectorType(type) &&
160  "`type` not multiple of SME tiles");
161  int64_t vectorRows = type.getDimSize(0);
162  int64_t vectorCols = type.getDimSize(1);
163  auto elementType = type.getElementType();
164  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
165  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
166 }
167 
168 /// Legalize `arith.constant dense<value>` splat operations to fit within SME
169 /// tiles by decomposing them into tile-sized operations.
170 struct LegalizeArithConstantOpsByDecomposition
171  : public OneToNOpConversionPattern<arith::ConstantOp> {
173 
175  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
176  OneToNPatternRewriter &rewriter) const override {
177  auto vectorType = dyn_cast<VectorType>(constantOp.getType());
178  auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
179  if (!vectorType || !denseAttr || !denseAttr.isSplat())
180  return failure();
181 
182  if (!isMultipleOfSMETileVectorType(vectorType))
183  return rewriter.notifyMatchFailure(constantOp,
184  kMatchFailureNotSMETileTypeMultiple);
185 
186  auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
187  auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
188  auto tileSplat = rewriter.create<arith::ConstantOp>(
189  constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
190  rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
191  adaptor.getResultMapping());
192 
193  return success();
194  }
195 };
196 
197 /// Legalize `vector.outerproduct` operations to fit within SME tiles by
198 /// decomposing them into tile-sized operations.
199 struct LegalizeVectorOuterProductOpsByDecomposition
200  : public OneToNOpConversionPattern<vector::OuterProductOp> {
202 
204  matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
205  OneToNPatternRewriter &rewriter) const override {
206  auto vectorType = outerProductOp.getResultVectorType();
207  if (!isMultipleOfSMETileVectorType(vectorType))
208  return rewriter.notifyMatchFailure(outerProductOp,
209  kMatchFailureNotSMETileTypeMultiple);
210 
211  Value mask;
212  Operation *rootOp = outerProductOp;
213  auto loc = outerProductOp.getLoc();
214  if (outerProductOp.isMasked()) {
215  auto maskOp = outerProductOp.getMaskingOp();
216  mask = maskOp.getMask();
217  rootOp = maskOp;
218  }
219 
220  if (!isSupportedMaskOp(mask))
221  return rewriter.notifyMatchFailure(outerProductOp,
222  kMatchFailureUnsupportedMaskOp);
223 
224  ValueRange accSMETiles = adaptor.getAcc();
225  auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
226  VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);
227 
228  SmallVector<Value> resultSMETiles;
229  for (auto [index, smeTile] : llvm::enumerate(
230  decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
231 
232  auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
233  auto lhs = rewriter.create<vector::ScalableExtractOp>(
234  loc, sliceType, outerProductOp.getLhs(), smeTile.row);
235  auto rhs = rewriter.create<vector::ScalableExtractOp>(
236  loc, sliceType, outerProductOp.getRhs(), smeTile.col);
237  auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
238  loc, smeTileType, lhs, rhs,
239  !accSMETiles.empty() ? accSMETiles[index] : Value{},
240  outerProductOp.getKind());
241 
242  auto maskedOuterProduct =
243  vector::maskOperation(rewriter, smeOuterProduct, smeMask);
244  resultSMETiles.push_back(maskedOuterProduct->getResult(0));
245  }
246 
247  rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
248  return success();
249  }
250 };
251 
252 // Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
253 // get the help of the type conversion), but doing so results in the type
254 // conversion adding target materializations in the `vector.mask` region
255 // (invalid). This pattern matches on `vector.mask` then calls into the
256 // `vector.outerproduct` pattern to work around this issue.
257 struct LegalizeMaskedVectorOuterProductOpsByDecomposition
258  : public OneToNOpConversionPattern<vector::MaskOp> {
260 
262  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
263  OneToNPatternRewriter &rewriter) const override {
264  if (auto outerProductOp =
265  llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
266  LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
267  getContext());
268  return static_cast<RewritePattern &>(pattern).matchAndRewrite(
269  outerProductOp, rewriter);
270  }
271  return failure();
272  }
273 };
274 
275 /// Legalize `vector.transfer_read` operations to fit within SME tiles by
276 /// decomposing them into tile-sized operations.
277 struct LegalizeTransferReadOpsByDecomposition
278  : public OneToNOpConversionPattern<vector::TransferReadOp> {
280 
282  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
283  OneToNPatternRewriter &rewriter) const override {
284  auto vectorType = readOp.getVectorType();
285  if (!isMultipleOfSMETileVectorType(vectorType))
286  return rewriter.notifyMatchFailure(readOp,
287  kMatchFailureNotSMETileTypeMultiple);
288 
289  auto mask = readOp.getMask();
290  if (!isSupportedMaskOp(mask))
291  return rewriter.notifyMatchFailure(readOp,
292  kMatchFailureUnsupportedMaskOp);
293 
294  auto permutationMap = readOp.getPermutationMap();
295  if (!permutationMap.isPermutation())
296  return rewriter.notifyMatchFailure(readOp,
297  kMatchFailureNonPermutationMap);
298 
299  // Note: For 2D vector types the only non-identity permutation is a simple
300  // tranpose [1, 0].
301  bool transposed = !permutationMap.isIdentity();
302 
303  auto loc = readOp.getLoc();
304  auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
305 
306  SmallVector<Value> resultSMETiles;
307  for (SMESubTile smeTile :
308  decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
309  auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
310  auto smeRead = rewriter.create<vector::TransferReadOp>(
311  loc, smeTileType, readOp.getSource(),
312  getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
313  readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
314  readOp.getInBoundsAttr());
315  resultSMETiles.push_back(smeRead);
316  }
317 
318  rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
319  return success();
320  }
321 };
322 
323 /// Legalize `vector.transfer_write` operations to fit within SME tiles by
324 /// decomposing them into tile-sized operations.
325 struct LegalizeTransferWriteOpsByDecomposition
326  : public OneToNOpConversionPattern<vector::TransferWriteOp> {
328 
330  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
331  OneToNPatternRewriter &rewriter) const override {
332  auto vectorType = writeOp.getVectorType();
333  if (!isMultipleOfSMETileVectorType(vectorType))
334  return rewriter.notifyMatchFailure(writeOp,
335  kMatchFailureNotSMETileTypeMultiple);
336 
337  auto mask = writeOp.getMask();
338  if (!isSupportedMaskOp(mask))
339  return rewriter.notifyMatchFailure(writeOp,
340  kMatchFailureUnsupportedMaskOp);
341 
342  auto permutationMap = writeOp.getPermutationMap();
343  if (!permutationMap.isPermutation())
344  return rewriter.notifyMatchFailure(writeOp,
345  kMatchFailureNonPermutationMap);
346 
347  // Note: For 2D vector types the only non-identity permutation is a simple
348  // tranpose [1, 0].
349  bool transposed = !permutationMap.isIdentity();
350 
351  auto loc = writeOp.getLoc();
352  auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
353  auto inputSMETiles = adaptor.getVector();
354 
355  Value destTensorOrMemref = writeOp.getSource();
356  for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
357  rewriter, vectorType, smeTileType, transposed))) {
358  auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
359  auto smeWrite = rewriter.create<vector::TransferWriteOp>(
360  loc, inputSMETiles[index], destTensorOrMemref,
361  getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
362  writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
363  if (writeOp.hasPureTensorSemantics())
364  destTensorOrMemref = smeWrite.getResult();
365  }
366 
367  if (writeOp.hasPureTensorSemantics())
368  rewriter.replaceOp(writeOp, destTensorOrMemref);
369  else
370  rewriter.eraseOp(writeOp);
371 
372  return success();
373  }
374 };
375 
376 //===----------------------------------------------------------------------===//
377 // ArmSME-specific fixup canonicalizations/folds
378 //===----------------------------------------------------------------------===//
379 
380 /// Folds an extract from a 3D `vector.create_mask` (which is a vector of
381 /// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
382 /// necessary for the mask to be lowered to ArmSME.
383 ///
384 /// Example:
385 ///
386 /// BEFORE:
387 /// ```mlir
388 /// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
389 /// %subMask = vector.extract %mask[2]
390 /// : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
391 /// ```
392 ///
393 /// AFTER:
394 /// ```mlir
395 /// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
396 /// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
397 /// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
398 /// ```
399 struct FoldExtractFromVectorOfSMELikeCreateMasks
400  : public OpRewritePattern<vector::ExtractOp> {
402 
403  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
404  PatternRewriter &rewriter) const override {
405  auto loc = extractOp.getLoc();
406  auto createMaskOp =
407  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
408  if (!createMaskOp)
409  return rewriter.notifyMatchFailure(
410  extractOp, "extract not from vector.create_mask op");
411 
412  VectorType extractedMaskType =
413  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
414  if (!extractedMaskType)
415  return rewriter.notifyMatchFailure(extractOp,
416  "extracted type is not a vector type");
417 
418  auto numScalable = llvm::count(extractedMaskType.getScalableDims(), true);
419  if (numScalable != 2)
420  return rewriter.notifyMatchFailure(
421  extractOp, "expected extracted type to be an SME-like mask");
422 
423  // TODO: Support multiple extraction indices.
424  if (extractOp.getStaticPosition().size() != 1)
425  return rewriter.notifyMatchFailure(
426  extractOp, "only a single extraction index is supported");
427 
428  auto frontMaskDim = createMaskOp.getOperand(0);
429  if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
430  return rewriter.notifyMatchFailure(
431  extractOp,
432  "constant vector.create_masks dims should be folded elsewhere");
433 
434  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
435  auto extractionIndex = getValueOrCreateConstantIndexOp(
436  rewriter, loc, extractOp.getMixedPosition()[0]);
437  auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
438  loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
439  frontMaskDim);
440  auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
441  loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);
442 
443  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
444  extractOp, extractedMaskType,
445  ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
446  return success();
447  }
448 };
449 
450 /// A vector type where no fixed dimension comes after a scalable dimension.
451 bool isLegalVectorType(VectorType vType) {
452  bool seenFixedDim = false;
453  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
454  seenFixedDim |= !scalableFlag;
455  if (seenFixedDim && scalableFlag)
456  return false;
457  }
458  return true;
459 }
460 
461 /// Lifts an illegal vector.transpose and vector.transfer_read to a
462 /// memref.subview + memref.transpose, followed by a legal read.
463 ///
464 /// 'Illegal' here means a leading scalable dimension and a fixed trailing
465 /// dimension, which has no valid lowering.
466 ///
467 /// The memref.transpose is metadata-only transpose that produces a strided
468 /// memref, which eventually becomes a loop reading individual elements.
469 ///
470 /// Example:
471 ///
472 /// BEFORE:
473 /// ```mlir
474 /// %illegalRead = vector.transfer_read %memref[%a, %b]
475 /// : memref<?x?xf32>, vector<[8]x4xf32>
476 /// %legalType = vector.transpose %illegalRead, [1, 0]
477 /// : vector<[8]x4xf32> to vector<4x[8]xf32>
478 /// ```
479 ///
480 /// AFTER:
481 /// ```mlir
482 /// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
483 /// : memref<?x?xf32> to memref<?x?xf32>
484 /// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
485 /// : memref<?x?xf32> to memref<?x?xf32>
486 /// %legalType = vector.transfer_read %transpose[%c0, %c0]
487 /// : memref<?x?xf32>, vector<4x[8]xf32>
488 /// ```
489 struct LiftIllegalVectorTransposeToMemory
490  : public OpRewritePattern<vector::TransposeOp> {
492 
493  static Value getExtensionSource(Operation *op) {
494  if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
495  return op->getOperand(0);
496  return {};
497  }
498 
499  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
500  PatternRewriter &rewriter) const override {
501  auto sourceType = transposeOp.getSourceVectorType();
502  auto resultType = transposeOp.getResultVectorType();
503  if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
504  return rewriter.notifyMatchFailure(transposeOp,
505  kMatchFailureNotIllegalToLegal);
506 
507  // Look through extend for transfer_read.
508  Value maybeRead = transposeOp.getVector();
509  auto *transposeSourceOp = maybeRead.getDefiningOp();
510  Operation *extendOp = nullptr;
511  if (Value extendSource = getExtensionSource(transposeSourceOp)) {
512  maybeRead = extendSource;
513  extendOp = transposeSourceOp;
514  }
515 
516  auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
517  if (!illegalRead)
518  return rewriter.notifyMatchFailure(
519  transposeOp,
520  "expected source to be (possibly extended) transfer_read");
521 
522  if (!illegalRead.getPermutationMap().isIdentity())
523  return rewriter.notifyMatchFailure(
524  illegalRead, "expected read to have identity permutation map");
525 
526  auto loc = transposeOp.getLoc();
527  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
528  auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
529 
530  // Create a subview that matches the size of the illegal read vector type.
531  auto readType = illegalRead.getVectorType();
532  auto readSizes = llvm::map_to_vector(
533  llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
534  [&](auto dim) -> Value {
535  auto [size, isScalable] = dim;
536  auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
537  if (!isScalable)
538  return dimSize;
539  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
540  return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
541  });
542  SmallVector<Value> strides(readType.getRank(), Value(one));
543  auto readSubview = rewriter.create<memref::SubViewOp>(
544  loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
545  strides);
546 
547  // Apply the transpose to all values/attributes of the transfer_read:
548  // - The mask
549  Value mask = illegalRead.getMask();
550  if (mask) {
551  // Note: The transpose for the mask should fold into the
552  // vector.create_mask/constant_mask op, which will then become legal.
553  mask = rewriter.create<vector::TransposeOp>(loc, mask,
554  transposeOp.getPermutation());
555  }
556  // - The source memref
558  transposeOp.getPermutation(), getContext());
559  auto transposedSubview = rewriter.create<memref::TransposeOp>(
560  loc, readSubview, AffineMapAttr::get(transposeMap));
561  ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
562  // - The `in_bounds` attribute
563  if (inBoundsAttr) {
564  SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
565  inBoundsAttr.end());
566  applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
567  inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
568  }
569 
570  VectorType legalReadType = resultType.clone(readType.getElementType());
571  // Note: The indices are all zero as the subview is already offset.
572  SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
573  auto legalRead = rewriter.create<vector::TransferReadOp>(
574  loc, legalReadType, transposedSubview, readIndices,
575  illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
576  inBoundsAttr);
577 
578  // Replace the transpose with the new read, extending the result if
579  // necessary.
580  rewriter.replaceOp(transposeOp, [&]() -> Operation * {
581  if (extendOp)
582  return rewriter.create(loc, extendOp->getName().getIdentifier(),
583  Value(legalRead), resultType);
584  return legalRead;
585  }());
586 
587  return success();
588  }
589 };
590 
591 /// A rewrite to turn unit dim transpose-like vector.shape_casts into
592 /// vector.transposes. The shape_cast has to be from an illegal vector type to a
593 /// legal one (as defined by isLegalVectorType).
594 ///
595 /// The reasoning for this is if we've got to this pass and we still have
596 /// shape_casts of illegal types, then they likely will not cancel out. Turning
597 /// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
598 /// eliminate them.
599 ///
600 /// Example:
601 ///
602 /// BEFORE:
603 /// ```mlir
604 /// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
605 /// ```
606 ///
607 /// AFTER:
608 /// ```mlir
609 /// %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
610 /// ```
611 struct ConvertIllegalShapeCastOpsToTransposes
612  : public OpRewritePattern<vector::ShapeCastOp> {
614 
615  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
616  PatternRewriter &rewriter) const override {
617  auto sourceType = shapeCastOp.getSourceVectorType();
618  auto resultType = shapeCastOp.getResultVectorType();
619  if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
620  return rewriter.notifyMatchFailure(shapeCastOp,
621  kMatchFailureNotIllegalToLegal);
622 
623  // Note: If we know that `sourceType` is an illegal vector type (and 2D)
624  // then dim 0 is scalable and dim 1 is fixed.
625  if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
626  return rewriter.notifyMatchFailure(
627  shapeCastOp, "expected source to be a 2D scalable vector with a "
628  "trailing unit dim");
629 
630  auto loc = shapeCastOp.getLoc();
631  auto transpose = rewriter.create<vector::TransposeOp>(
632  loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
633 
634  if (resultType.getRank() == 1)
635  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
636  transpose);
637  else
638  rewriter.replaceOp(shapeCastOp, transpose);
639 
640  return success();
641  }
642 };
643 
644 struct VectorLegalizationPass
645  : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
646  void runOnOperation() override {
647  auto *context = &getContext();
648  OneToNTypeConverter converter;
649  RewritePatternSet patterns(context);
650  converter.addConversion([](Type type) { return type; });
651  converter.addConversion(
652  [](VectorType vectorType,
653  SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
654  if (!isMultipleOfSMETileVectorType(vectorType))
655  return std::nullopt;
656  auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
657  auto smeTileType =
658  getSMETileTypeForElement(vectorType.getElementType());
659  types = SmallVector<Type>(smeTileCount, smeTileType);
660  return success();
661  });
662 
663  patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
664  LiftIllegalVectorTransposeToMemory,
665  ConvertIllegalShapeCastOpsToTransposes>(context);
666  // Note: High benefit to ensure masked outer products are lowered first.
667  patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
668  converter, context, 1024);
669  patterns.add<LegalizeArithConstantOpsByDecomposition,
670  LegalizeVectorOuterProductOpsByDecomposition,
671  LegalizeTransferReadOpsByDecomposition,
672  LegalizeTransferWriteOpsByDecomposition>(converter, context);
673  populateFuncTypeConversionPatterns(converter, patterns);
675 
676  if (failed(applyPartialOneToNConversion(getOperation(), converter,
677  std::move(patterns))))
678  return signalPassFailure();
679  }
680 };
681 
682 } // namespace
683 
685  return std::make_unique<VectorLegalizationPass>();
686 }
static MLIRContext * getContext(OpFoldResult val)
static Value createMask(AffineForOp vecForOp, VectorizationState &state)
Creates a mask used to filter out garbage elements in the last iteration of unaligned loops.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:248
IntegerType getI1Type()
Definition: Builders.cpp:73
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class is a wrapper around OneToNConversionPattern for matching against instances of a particular...
OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Specialization of PatternRewriter that OneToNConversionPatterns use.
void replaceOp(Operation *op, ValueRange newValues, const OneToNTypeMapping &resultMapping)
Replaces the results of the operation with the specified list of values mapped back to the original t...
Extends TypeConverter with 1:N target materializations.
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:330
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
Definition: Utils.cpp:114
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:18
std::unique_ptr< Pass > createVectorLegalizationPass()
Pass that legalizes vectors so they can be lowered to ArmSME.
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
Definition: Utils.cpp:97
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the provided pattern set with patterns that do 1:N type conversions on (some) SCF ops.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:21
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358