1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <numeric>
14 #include <optional>
15 #include <type_traits>
16 
17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
18 
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/Tensor/IR/Tensor.h"
24 #include "mlir/Dialect/Vector/IR/VectorOps.h"
25 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
26 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/ImplicitLocOpBuilder.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 #include "mlir/Transforms/Passes.h"
33 
34 namespace mlir {
35 #define GEN_PASS_DEF_CONVERTVECTORTOSCF
36 #include "mlir/Conversion/Passes.h.inc"
37 } // namespace mlir
38 
39 using namespace mlir;
40 using vector::TransferReadOp;
41 using vector::TransferWriteOp;
42 
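// Usage sketch (illustrative, not part of the upstream file): these patterns
// are normally pulled in via the public entry point declared in
// mlir/Conversion/VectorToSCF/VectorToSCF.h. Option spellings follow the
// VectorTransferToSCFOptions members used below (targetRank, lowerTensors);
// the driver call at the end is only an example.
//
//   RewritePatternSet patterns(ctx);
//   VectorTransferToSCFOptions options;
//   options.targetRank = 1;        // peel dims until transfers are 1-D
//   options.lowerTensors = false;  // leave transfers on tensors untouched
//   populateVectorToSCFConversionPatterns(patterns, options);
//   // Then apply with a greedy pattern driver (e.g. applyPatternsGreedily).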
43 namespace {
44 
45 /// Attribute name used for labeling transfer ops during progressive lowering.
46 static const char kPassLabel[] = "__vector_to_scf_lowering__";
47 
48 /// Return true if this transfer op operates on a source tensor.
49 static bool isTensorOp(VectorTransferOpInterface xferOp) {
50  if (isa<RankedTensorType>(xferOp.getShapedType())) {
51  if (isa<vector::TransferWriteOp>(xferOp)) {
52  // TransferWriteOps on tensors have a result.
53  assert(xferOp->getNumResults() > 0);
54  }
55  return true;
56  }
57  return false;
58 }
59 
60 /// Patterns that inherit from this struct have access to
61 /// VectorTransferToSCFOptions.
62 template <typename OpTy>
63 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
64  explicit VectorToSCFPattern(MLIRContext *context,
65                              VectorTransferToSCFOptions opt)
66  : OpRewritePattern<OpTy>(context), options(opt) {}
67 
68  LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
69  PatternRewriter &rewriter) const {
70  if (isTensorOp(xferOp) && !options.lowerTensors) {
71  return rewriter.notifyMatchFailure(
72  xferOp, "lowering tensor transfers is disabled");
73  }
74  return success();
75  }
76 
77  VectorTransferToSCFOptions options;
78 };
79 
80 /// Given a vector transfer op, calculate which dimension of the `source`
81 /// memref should be unpacked in the next application of TransferOpConversion.
82 /// A return value of std::nullopt indicates a broadcast.
83 template <typename OpTy>
84 static std::optional<int64_t> unpackedDim(OpTy xferOp) {
85  // TODO: support 0-d corner case.
86  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
87  auto map = xferOp.getPermutationMap();
88  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
89  return expr.getPosition();
90  }
91  assert(xferOp.isBroadcastDim(0) &&
92  "Expected AffineDimExpr or AffineConstantExpr");
93  return std::nullopt;
94 }
95 
96 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
97 /// map is identical to the current permutation map, but the first result is
98 /// omitted.
99 template <typename OpTy>
100 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
101  // TODO: support 0-d corner case.
102  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
103  auto map = xferOp.getPermutationMap();
104  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
105  b.getContext());
106 }
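// Worked example (illustrative): for a transfer whose permutation map is
//   (d0, d1, d2) -> (d2, d1)
// unpackedDim returns 2 (the memref dim feeding the major vector dim) and
// unpackedPermutationMap returns (d0, d1, d2) -> (d1) for the new (N-1)-D
// transfer. If the first map result is a constant 0 (a broadcast),
// unpackedDim returns std::nullopt and no memref dimension is advanced.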
107 
108 /// Calculate the indices for the new vector transfer op.
109 ///
110 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
111 /// --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
112 /// ^^^^^^
113 /// `iv` is the iteration variable of the (new) surrounding loop.
114 template <typename OpTy>
115 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
116  SmallVector<Value, 8> &indices) {
117  typename OpTy::Adaptor adaptor(xferOp);
118  // Corresponding memref dim of the vector dim that is unpacked.
119  auto dim = unpackedDim(xferOp);
120  auto prevIndices = adaptor.getIndices();
121  indices.append(prevIndices.begin(), prevIndices.end());
122 
123  Location loc = xferOp.getLoc();
124  bool isBroadcast = !dim.has_value();
125  if (!isBroadcast) {
126  AffineExpr d0, d1;
127  bindDims(xferOp.getContext(), d0, d1);
128  Value offset = adaptor.getIndices()[*dim];
129  indices[*dim] =
130  affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
131  }
132 }
133 
134 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
135  Value value) {
136  if (hasRetVal) {
137  assert(value && "Expected non-empty value");
138  b.create<scf::YieldOp>(loc, value);
139  } else {
140  b.create<scf::YieldOp>(loc);
141  }
142 }
143 
144 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
145 /// is set to true. No such check is generated under the following circumstances:
146 /// * xferOp does not have a mask.
147 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
148 /// computed and attached to the new transfer op in the pattern.)
149 /// * The to-be-unpacked dim of xferOp is a broadcast.
150 template <typename OpTy>
151 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
152  if (!xferOp.getMask())
153  return Value();
154  if (xferOp.getMaskType().getRank() != 1)
155  return Value();
156  if (xferOp.isBroadcastDim(0))
157  return Value();
158 
159  Location loc = xferOp.getLoc();
160  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
161 }
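// E.g., for a 1-D mask %mask : vector<5xi1>, the generated check is
// (illustrative):
//   %bit = vector.extractelement %mask[%iv : index] : vector<5xi1>
// and the resulting i1 is later AND-ed with the in-bounds condition.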
162 
163 /// Helper function for TransferOpConversion and TransferOp1dConversion.
164 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
165 /// specified dimension `dim` with the loop iteration variable `iv`.
166 /// E.g., when unpacking dimension 0 from:
167 /// ```
168 /// %vec = vector.transfer_read %A[%a, %b], %cst
169 ///     : memref<?x?xf32>, vector<5x4xf32>
170 /// ```
171 /// An if check similar to this will be generated inside the loop:
172 /// ```
173 /// %d = memref.dim %A, %c0 : memref<?x?xf32>
174 /// if (%a + iv < %d) {
175 /// (in-bounds case)
176 /// } else {
177 /// (out-of-bounds case)
178 /// }
179 /// ```
180 ///
181 /// If the transfer is 1D and has a mask, this function generates a more complex
182 /// check that also accounts for potentially masked-out elements.
183 ///
184 /// This function variant returns the value returned by `inBoundsCase` or
185 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
186 /// `resultTypes`.
187 template <typename OpTy>
188 static Value generateInBoundsCheck(
189  OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
190  TypeRange resultTypes,
191  function_ref<Value(OpBuilder &, Location)> inBoundsCase,
192  function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
193  bool hasRetVal = !resultTypes.empty();
194  Value cond; // Condition to be built...
195 
196  // Condition check 1: Access in-bounds?
197  bool isBroadcast = !dim; // No in-bounds check for broadcasts.
198  Location loc = xferOp.getLoc();
199  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
200  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
201  Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.getBase(), *dim);
202  AffineExpr d0, d1;
203  bindDims(xferOp.getContext(), d0, d1);
204  Value base = xferOp.getIndices()[*dim];
205  Value memrefIdx =
206  affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
207  cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
208  memrefIdx);
209  }
210 
211  // Condition check 2: Masked in?
212  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
213  if (cond)
214  cond = lb.create<arith::AndIOp>(cond, maskCond);
215  else
216  cond = maskCond;
217  }
218 
219  // If the condition is non-empty, generate an SCF::IfOp.
220  if (cond) {
221  auto check = lb.create<scf::IfOp>(
222  cond,
223  /*thenBuilder=*/
224  [&](OpBuilder &b, Location loc) {
225  maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
226  },
227  /*elseBuilder=*/
228  [&](OpBuilder &b, Location loc) {
229  if (outOfBoundsCase) {
230  maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
231  } else {
232  b.create<scf::YieldOp>(loc);
233  }
234  });
235 
236  return hasRetVal ? check.getResult(0) : Value();
237  }
238 
239  // Condition is empty, no need for an SCF::IfOp.
240  return inBoundsCase(b, loc);
241 }
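// Minimal usage sketch (illustrative; aType, aVal and bVal are placeholder
// names): yield aVal when the access is in bounds and bVal otherwise.
//   Value r = generateInBoundsCheck(
//       b, xferOp, iv, unpackedDim(xferOp), TypeRange(aType),
//       /*inBoundsCase=*/[&](OpBuilder &b, Location loc) { return aVal; },
//       /*outOfBoundsCase=*/[&](OpBuilder &b, Location loc) { return bVal; });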
242 
243 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
244 /// a return value. Consequently, this function does not have a return value.
245 template <typename OpTy>
246 static void generateInBoundsCheck(
247  OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
248  function_ref<void(OpBuilder &, Location)> inBoundsCase,
249  function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
250  generateInBoundsCheck(
251  b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
252  /*inBoundsCase=*/
253  [&](OpBuilder &b, Location loc) {
254  inBoundsCase(b, loc);
255  return Value();
256  },
257  /*outOfBoundsCase=*/
258  [&](OpBuilder &b, Location loc) {
259  if (outOfBoundsCase)
260  outOfBoundsCase(b, loc);
261  return Value();
262  });
263 }
264 
265 /// Given an ArrayAttr, return a copy where the first element is dropped.
266 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
267  if (!attr)
268  return attr;
269  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
270 }
271 
272 /// Add the pass label to a vector transfer op if its rank is not the target
273 /// rank.
274 template <typename OpTy>
275 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
276  unsigned targetRank) {
277  if (newXferOp.getVectorType().getRank() > targetRank)
278  newXferOp->setAttr(kPassLabel, b.getUnitAttr());
279 }
280 
281 namespace lowering_n_d {
282 
283 /// Helper data structure for data and mask buffers.
284 struct BufferAllocs {
285  Value dataBuffer;
286  Value maskBuffer;
287 };
288 
289 // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
290 static Operation *getAutomaticAllocationScope(Operation *op) {
291  Operation *scope =
292      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
293  assert(scope && "Expected op to be inside automatic allocation scope");
294  return scope;
295 }
296 
297 /// Allocate temporary buffers for data (vector) and mask (if present).
298 template <typename OpTy>
299 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
300  Location loc = xferOp.getLoc();
301  OpBuilder::InsertionGuard guard(b);
302  Operation *scope = getAutomaticAllocationScope(xferOp);
303  assert(scope->getNumRegions() == 1 &&
304  "AutomaticAllocationScope with >1 regions");
305  b.setInsertionPointToStart(&scope->getRegion(0).front());
306 
307  BufferAllocs result;
308  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
309  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
310 
311  if (xferOp.getMask()) {
312  auto maskType = MemRefType::get({}, xferOp.getMask().getType());
313  auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
314  b.setInsertionPoint(xferOp);
315  b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
316  result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
317  }
318 
319  return result;
320 }
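// Illustrative result for a masked vector<5x4xf32> transfer (SSA names are
// invented for the example):
//   %data    = memref.alloca() : memref<vector<5x4xf32>>
//   %maskMem = memref.alloca() : memref<vector<5x4xi1>>
//   memref.store %mask, %maskMem[] : memref<vector<5x4xi1>>
//   %maskVal = memref.load %maskMem[] : memref<vector<5x4xi1>>
// BufferAllocs then holds {%data, %maskVal}. The allocas are hoisted to the
// closest automatic allocation scope; the store/load stay next to the xferOp.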
321 
322 /// Given a MemRefType with VectorType element type, unpack one dimension from
323 /// the VectorType into the MemRefType.
324 ///
325 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
326 static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
327  auto vectorType = dyn_cast<VectorType>(type.getElementType());
328  // Vectors with leading scalable dims are not supported.
329  // It may be possible to support these in the future by using dynamic memref dims.
330  if (vectorType.getScalableDims().front())
331  return failure();
332  auto memrefShape = type.getShape();
333  SmallVector<int64_t, 8> newMemrefShape;
334  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
335  newMemrefShape.push_back(vectorType.getDimSize(0));
336  return MemRefType::get(newMemrefShape,
337  VectorType::Builder(vectorType).dropDim(0));
338 }
339 
340 /// Given a transfer op, find the memref from which the mask is loaded. This
341 /// is similar to Strategy<TransferWriteOp>::getBuffer.
342 template <typename OpTy>
343 static Value getMaskBuffer(OpTy xferOp) {
344  assert(xferOp.getMask() && "Expected that transfer op has mask");
345  auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
346  assert(loadOp && "Expected transfer op mask produced by LoadOp");
347  return loadOp.getMemRef();
348 }
349 
350 /// Codegen strategy, depending on the operation.
351 template <typename OpTy>
352 struct Strategy;
353 
354 /// Codegen strategy for vector TransferReadOp.
355 template <>
356 struct Strategy<TransferReadOp> {
357  /// Find the StoreOp that is used for writing the current TransferReadOp's
358  /// result to the temporary buffer allocation.
359  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
360  assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
361  auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
362  assert(storeOp && "Expected TransferReadOp result used by StoreOp");
363  return storeOp;
364  }
365 
366  /// Find the temporary buffer allocation. All labeled TransferReadOps are
367  /// used like this, where %buf is either the buffer allocation or a type cast
368  /// of the buffer allocation:
369  /// ```
370  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
371  /// memref.store %vec, %buf[...] ...
372  /// ```
373  static Value getBuffer(TransferReadOp xferOp) {
374  return getStoreOp(xferOp).getMemRef();
375  }
376 
377  /// Retrieve the indices of the current StoreOp that stores into the buffer.
378  static void getBufferIndices(TransferReadOp xferOp,
379  SmallVector<Value, 8> &indices) {
380  auto storeOp = getStoreOp(xferOp);
381  auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
382  indices.append(prevIndices.begin(), prevIndices.end());
383  }
384 
385  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
386  /// accesses on the to-be-unpacked dimension.
387  ///
388  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
389  /// variable `iv`.
390  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
391  ///
392  /// E.g.:
393  /// ```
394  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
395  /// : memref<?x?x?xf32>, vector<4x3xf32>
396  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
397  /// ```
398  /// Is rewritten to:
399  /// ```
400  /// %casted = vector.type_cast %buf
401  /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
402  /// for %j = 0 to 4 {
403  /// %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
404  /// : memref<?x?x?xf32>, vector<3xf32>
405  /// memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
406  /// }
407  /// ```
408  ///
409  /// Note: The loop and type cast are generated in TransferOpConversion.
410  /// The original TransferReadOp and store op are deleted in `cleanup`.
411  /// Note: The `mask` operand is set in TransferOpConversion.
412  static TransferReadOp rewriteOp(OpBuilder &b,
413                                  VectorTransferToSCFOptions options,
414  TransferReadOp xferOp, Value buffer, Value iv,
415  ValueRange /*loopState*/) {
416  SmallVector<Value, 8> storeIndices;
417  getBufferIndices(xferOp, storeIndices);
418  storeIndices.push_back(iv);
419 
420  SmallVector<Value, 8> xferIndices;
421  getXferIndices(b, xferOp, iv, xferIndices);
422 
423  Location loc = xferOp.getLoc();
424  auto bufferType = dyn_cast<ShapedType>(buffer.getType());
425  auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
426  auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
427  auto newXferOp = b.create<vector::TransferReadOp>(
428  loc, vecType, xferOp.getBase(), xferIndices,
429  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
430  xferOp.getPadding(), Value(), inBoundsAttr);
431 
432  maybeApplyPassLabel(b, newXferOp, options.targetRank);
433 
434  b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
435  return newXferOp;
436  }
437 
438  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
439  /// padding value to the temporary buffer.
440  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
441  Value buffer, Value iv,
442  ValueRange /*loopState*/) {
443  SmallVector<Value, 8> storeIndices;
444  getBufferIndices(xferOp, storeIndices);
445  storeIndices.push_back(iv);
446 
447  Location loc = xferOp.getLoc();
448  auto bufferType = dyn_cast<ShapedType>(buffer.getType());
449  auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
450  auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
451  b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
452 
453  return Value();
454  }
455 
456  /// Cleanup after rewriting the op.
457  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
458  scf::ForOp /*forOp*/) {
459  rewriter.eraseOp(getStoreOp(xferOp));
460  rewriter.eraseOp(xferOp);
461  }
462 
463  /// Return the initial loop state for the generated scf.for loop.
464  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
465 };
466 
467 /// Codegen strategy for vector TransferWriteOp.
468 template <>
469 struct Strategy<TransferWriteOp> {
470  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
471  /// used like this, where %buf is either the buffer allocation or a type cast
472  /// of the buffer allocation:
473  /// ```
474  /// %vec = memref.load %buf[...] ...
475  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
476  /// ```
477  static Value getBuffer(TransferWriteOp xferOp) {
478  auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
479  assert(loadOp && "Expected transfer op vector produced by LoadOp");
480  return loadOp.getMemRef();
481  }
482 
483  /// Retrieve the indices of the current LoadOp that loads from the buffer.
484  static void getBufferIndices(TransferWriteOp xferOp,
485  SmallVector<Value, 8> &indices) {
486  auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
487  auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
488  indices.append(prevIndices.begin(), prevIndices.end());
489  }
490 
491  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
492  /// accesses on the to-be-unpacked dimension.
493  ///
494  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
495  /// using the loop iteration variable `iv`.
496  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
497  /// to memory.
498  ///
499  /// Note: For more details, see comments on Strategy<TransferReadOp>.
500  static TransferWriteOp rewriteOp(OpBuilder &b,
501                                   VectorTransferToSCFOptions options,
502  TransferWriteOp xferOp, Value buffer,
503  Value iv, ValueRange loopState) {
504  SmallVector<Value, 8> loadIndices;
505  getBufferIndices(xferOp, loadIndices);
506  loadIndices.push_back(iv);
507 
508  SmallVector<Value, 8> xferIndices;
509  getXferIndices(b, xferOp, iv, xferIndices);
510 
511  Location loc = xferOp.getLoc();
512  auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
513  auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
514  auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
515  Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
516  auto newXferOp = b.create<vector::TransferWriteOp>(
517  loc, type, vec, source, xferIndices,
518  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
519  inBoundsAttr);
520 
521  maybeApplyPassLabel(b, newXferOp, options.targetRank);
522 
523  return newXferOp;
524  }
525 
526  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
527  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
528  Value buffer, Value iv,
529  ValueRange loopState) {
530  return isTensorOp(xferOp) ? loopState[0] : Value();
531  }
532 
533  /// Cleanup after rewriting the op.
534  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
535  scf::ForOp forOp) {
536  if (isTensorOp(xferOp)) {
537  assert(forOp->getNumResults() == 1 && "Expected one for loop result");
538  rewriter.replaceOp(xferOp, forOp->getResult(0));
539  } else {
540  rewriter.eraseOp(xferOp);
541  }
542  }
543 
544  /// Return the initial loop state for the generated scf.for loop.
545  static Value initialLoopState(TransferWriteOp xferOp) {
546  return isTensorOp(xferOp) ? xferOp.getBase() : Value();
547  }
548 };
549 
550 template <typename OpTy>
551 static LogicalResult checkPrepareXferOp(OpTy xferOp, PatternRewriter &rewriter,
552                                         VectorTransferToSCFOptions options) {
553  if (xferOp->hasAttr(kPassLabel))
554  return rewriter.notifyMatchFailure(
555  xferOp, "kPassLabel is present (vector-to-scf lowering in progress)");
556  if (xferOp.getVectorType().getRank() <= options.targetRank)
557  return rewriter.notifyMatchFailure(
558  xferOp, "xferOp vector rank <= transformation target rank");
559  if (xferOp.getVectorType().getScalableDims().front())
560  return rewriter.notifyMatchFailure(
561  xferOp, "Unpacking of the leading dimension into the memref is not yet "
562  "supported for scalable dims");
563  if (isTensorOp(xferOp) && !options.lowerTensors)
564  return rewriter.notifyMatchFailure(
565  xferOp, "Unpacking for tensors has been disabled.");
566  if (xferOp.getVectorType().getElementType() !=
567  xferOp.getShapedType().getElementType())
568  return rewriter.notifyMatchFailure(
569  xferOp, "Mismatching source and destination element types.");
570 
571  return success();
572 }
573 
574 /// Prepare a TransferReadOp for progressive lowering.
575 ///
576 /// 1. Allocate a temporary buffer.
577 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
578 /// 3. Store the result of the TransferReadOp into the temporary buffer.
579 /// 4. Load the result from the temporary buffer and replace all uses of the
580 /// original TransferReadOp with this load.
581 ///
582 /// E.g.:
583 /// ```
584 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
585 ///     : memref<?x?x?xf32>, vector<5x4xf32>
586 /// ```
587 /// is rewritten to:
588 /// ```
589 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
590 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
591 ///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
592 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
593 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
594 /// ```
595 ///
596 /// Note: A second temporary buffer may be allocated for the `mask` operand.
597 struct PrepareTransferReadConversion
598  : public VectorToSCFPattern<TransferReadOp> {
599  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
600 
601  LogicalResult matchAndRewrite(TransferReadOp xferOp,
602  PatternRewriter &rewriter) const override {
603  if (checkPrepareXferOp(xferOp, rewriter, options).failed())
604  return rewriter.notifyMatchFailure(
605  xferOp, "checkPrepareXferOp conditions not met!");
606 
607  auto buffers = allocBuffers(rewriter, xferOp);
608  auto *newXfer = rewriter.clone(*xferOp.getOperation());
609  newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
610  if (xferOp.getMask()) {
611  dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
612  buffers.maskBuffer);
613  }
614 
615  Location loc = xferOp.getLoc();
616  rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
617  buffers.dataBuffer);
618  rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
619 
620  return success();
621  }
622 };
623 
624 /// Prepare a TransferWriteOp for progressive lowering.
625 ///
626 /// 1. Allocate a temporary buffer.
627 /// 2. Store the vector into the buffer.
628 /// 3. Load the vector from the buffer again.
629 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
630 /// marking it eligible for progressive lowering via TransferOpConversion.
631 ///
632 /// E.g.:
633 /// ```
634 /// vector.transfer_write %vec, %A[%a, %b, %c]
635 /// : vector<5x4xf32>, memref<?x?x?xf32>
636 /// ```
637 /// is rewritten to:
638 /// ```
639 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
640 /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
641 /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
642 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
643 /// : vector<5x4xf32>, memref<?x?x?xf32>
644 /// ```
645 ///
646 /// Note: A second temporary buffer may be allocated for the `mask` operand.
647 struct PrepareTransferWriteConversion
648  : public VectorToSCFPattern<TransferWriteOp> {
649  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
650 
651  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
652  PatternRewriter &rewriter) const override {
653  if (checkPrepareXferOp(xferOp, rewriter, options).failed())
654  return rewriter.notifyMatchFailure(
655  xferOp, "checkPrepareXferOp conditions not met!");
656 
657  Location loc = xferOp.getLoc();
658  auto buffers = allocBuffers(rewriter, xferOp);
659  rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
660  buffers.dataBuffer);
661  auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
662  rewriter.modifyOpInPlace(xferOp, [&]() {
663  xferOp.getValueToStoreMutable().assign(loadedVec);
664  xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
665  });
666 
667  if (xferOp.getMask()) {
668  rewriter.modifyOpInPlace(xferOp, [&]() {
669  xferOp.getMaskMutable().assign(buffers.maskBuffer);
670  });
671  }
672 
673  return success();
674  }
675 };
676 
677 /// Decompose an n-D PrintOp into a loop of elementary/scalar prints. This allows
678 /// printing both 1D scalable vectors and n-D fixed size vectors.
679 ///
680 /// E.g.:
681 /// ```
682 /// vector.print %v : vector<[4]xi32>
683 /// ```
684 /// is rewritten to:
685 /// ```
686 /// %c0 = arith.constant 0 : index
687 /// %c4 = arith.constant 4 : index
688 /// %c1 = arith.constant 1 : index
689 /// %vscale = vector.vscale
690 /// %length = arith.muli %vscale, %c4 : index
691 /// %lastIndex = arith.subi %length, %c1 : index
692 /// vector.print punctuation <open>
693 /// scf.for %i = %c0 to %length step %c1 {
694 /// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
695 /// vector.print %el : i32 punctuation <no_punctuation>
696 /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
697 /// scf.if %notLastIndex {
698 /// vector.print punctuation <comma>
699 /// }
700 /// }
701 /// vector.print punctuation <close>
702 /// vector.print
703 /// ```
704 struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
705  using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
706  LogicalResult matchAndRewrite(vector::PrintOp printOp,
707  PatternRewriter &rewriter) const override {
708  if (!printOp.getSource())
709  return failure();
710 
711  VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
712  if (!vectorType)
713  return failure();
714 
715  // Currently >= 2D scalable vectors are not supported.
716  // These can't be lowered to LLVM (as LLVM does not support scalable vectors
717  // of scalable vectors), and due to limitations of current ops can't be
718  // indexed with SSA values or flattened. This may change after
719  // https://reviews.llvm.org/D155034, though there still needs to be a path
720  // for lowering to LLVM.
721  if (vectorType.getRank() > 1 && vectorType.isScalable())
722  return failure();
723 
724  auto loc = printOp.getLoc();
725  auto value = printOp.getSource();
726 
727  if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
728  // Oddly sized integers are (somewhat) buggy on a lot of backends, so to
729  // avoid issues extend them to a more standard size.
730  // https://github.com/llvm/llvm-project/issues/30613
731  auto width = intTy.getWidth();
732  auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
733  auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
734  intTy.getSignedness());
735  // arith can only take signless integers, so we must cast back and forth.
736  auto signlessSourceVectorType =
737  vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
738  auto signlessTargetVectorType =
739  vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
740  auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
741  value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
742  value);
743  if (value.getType() != signlessTargetVectorType) {
744  if (width == 1 || intTy.isUnsigned())
745  value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
746  value);
747  else
748  value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
749  value);
750  }
751  value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
752  vectorType = targetVectorType;
753  }
754 
755  auto scalableDimensions = vectorType.getScalableDims();
756  auto shape = vectorType.getShape();
757  constexpr int64_t singletonShape[] = {1};
758  if (vectorType.getRank() == 0)
759  shape = singletonShape;
760 
761  if (vectorType.getRank() != 1) {
762  // Flatten n-D vectors to 1D. This is done to allow indexing with a
763  // non-constant value (which can currently only be done via
764  // vector.extractelement for 1D vectors).
765  auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
766  std::multiplies<int64_t>());
767  auto flatVectorType =
768  VectorType::get({flatLength}, vectorType.getElementType());
769  value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
770  }
771 
772  vector::PrintOp firstClose;
773  SmallVector<Value, 8> loopIndices;
774  for (unsigned d = 0; d < shape.size(); d++) {
775  // Setup loop bounds and step.
776  Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
777  Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
778  Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
779  if (!scalableDimensions.empty() && scalableDimensions[d]) {
780  auto vscale = rewriter.create<vector::VectorScaleOp>(
781  loc, rewriter.getIndexType());
782  upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
783  }
784  auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);
785 
786  // Create a loop to print the elements surrounded by parentheses.
787  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
788  auto loop =
789  rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
790  auto printClose = rewriter.create<vector::PrintOp>(
791  loc, vector::PrintPunctuation::Close);
792  if (!firstClose)
793  firstClose = printClose;
794 
795  auto loopIdx = loop.getInductionVar();
796  loopIndices.push_back(loopIdx);
797 
798  // Print a comma after all but the last element.
799  rewriter.setInsertionPointToStart(loop.getBody());
800  auto notLastIndex = rewriter.create<arith::CmpIOp>(
801  loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
802  rewriter.create<scf::IfOp>(loc, notLastIndex,
803  [&](OpBuilder &builder, Location loc) {
804  builder.create<vector::PrintOp>(
805  loc, vector::PrintPunctuation::Comma);
806  builder.create<scf::YieldOp>(loc);
807  });
808 
809  rewriter.setInsertionPointToStart(loop.getBody());
810  }
811 
812  // Compute the flattened index.
813  // Note: For vectors of rank > 1 this assumes that all dimensions are non-scalable.
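 // E.g., for a vector of shape 2x3, the element at loop indices (%i, %j)
 // maps to flat position %i * 3 + %j (row-major strides).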
814  Value flatIndex;
815  auto currentStride = 1;
816  for (int d = shape.size() - 1; d >= 0; d--) {
817  auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
818  auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
819  if (flatIndex)
820  flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
821  else
822  flatIndex = index;
823  currentStride *= shape[d];
824  }
825 
826  // Print the scalar elements in the innermost loop.
827  auto element =
828  rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
829  rewriter.create<vector::PrintOp>(loc, element,
830  vector::PrintPunctuation::NoPunctuation);
831 
832  rewriter.setInsertionPointAfter(firstClose);
833  rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
834  rewriter.eraseOp(printOp);
835  return success();
836  }
837 
838  static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
839  return IntegerType::get(intTy.getContext(), intTy.getWidth(),
840  IntegerType::Signless);
841  };
842 };
843 
844 /// Progressive lowering of vector transfer ops: Unpack one dimension.
845 ///
846 /// 1. Unpack one dimension from the current buffer type and cast the buffer
847 /// to that new type. E.g.:
848 /// ```
849 /// %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
850 /// vector.transfer_write %vec ...
851 /// ```
852 /// The following cast is generated:
853 /// ```
854 /// %casted = vector.type_cast %0
855 /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
856 /// ```
857 /// 2. Generate a for loop and rewrite the transfer op according to the
858 /// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
859 /// out-of-bounds, generate an if-check and handle both cases separately.
860 /// 3. Clean up according to the corresponding Strategy<OpTy>.
861 ///
862 /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
863 /// source (as opposed to a memref source), then each iteration of the generated
864 /// scf.for loop yields the new tensor value. E.g.:
865 /// ```
866 /// %result = scf.for i = 0 to 5 {
867 /// %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
868 /// %1 = vector.transfer_write %0, %source[...]
869 /// : vector<4x3xf32>, tensor<5x4x3xf32>
870 /// scf.yield %1 : tensor<5x4x3xf32>
871 /// }
872 /// ```
873 template <typename OpTy>
874 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
875  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
876 
877  void initialize() {
878  // This pattern recursively unpacks one dimension at a time. The recursion
879  // is bounded because the rank is strictly decreasing.
880  this->setHasBoundedRewriteRecursion();
881  }
882 
883  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
884  SmallVectorImpl<Value> &loadIndices,
885  Value iv) {
886  assert(xferOp.getMask() && "Expected transfer op to have mask");
887 
888  // Add load indices from the previous iteration.
889  // The mask buffer depends on the permutation map, which makes determining
890  // the indices quite complex, which is why we need to "look back" to the
891  // previous iteration to find the right indices.
892  Value maskBuffer = getMaskBuffer(xferOp);
893  for (Operation *user : maskBuffer.getUsers()) {
894  // If there is no previous load op, then the indices are empty.
895  if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
896  Operation::operand_range prevIndices = loadOp.getIndices();
897  loadIndices.append(prevIndices.begin(), prevIndices.end());
898  break;
899  }
900  }
901 
902  // If the to-be-unpacked dim is a broadcast, reuse the same indices as
903  // before; otherwise index into the unpacked dim with the loop IV.
904  if (!xferOp.isBroadcastDim(0))
905  loadIndices.push_back(iv);
906  }
907 
908  LogicalResult matchAndRewrite(OpTy xferOp,
909  PatternRewriter &rewriter) const override {
910  if (!xferOp->hasAttr(kPassLabel))
911  return rewriter.notifyMatchFailure(
912  xferOp, "kPassLabel is absent (op was not prepared for progressive lowering)");
913 
914  // Find and cast data buffer. How the buffer can be found depends on OpTy.
915  ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
916  Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
917  auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
918  FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
919  if (failed(castedDataType))
920  return rewriter.notifyMatchFailure(xferOp,
921  "Failed to unpack one vector dim.");
922 
923  auto castedDataBuffer =
924  locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);
925 
926  // If the xferOp has a mask: Find and cast mask buffer.
927  Value castedMaskBuffer;
928  if (xferOp.getMask()) {
929  Value maskBuffer = getMaskBuffer(xferOp);
930  if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
931  // Do not unpack a dimension of the mask, if:
932  // * To-be-unpacked transfer op dimension is a broadcast.
933  // * Mask is 1D, i.e., the mask cannot be further unpacked.
934  // (That means that all remaining dimensions of the transfer op must
935  // be broadcasted.)
936  castedMaskBuffer = maskBuffer;
937  } else {
938  // It's safe to assume the mask buffer can be unpacked if the data
939  // buffer was unpacked.
940  auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
941  MemRefType castedMaskType = *unpackOneDim(maskBufferType);
942  castedMaskBuffer =
943  locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
944  }
945  }
946 
947  // Loop bounds and step.
948  auto lb = locB.create<arith::ConstantIndexOp>(0);
949  auto ub = locB.create<arith::ConstantIndexOp>(
950  castedDataType->getDimSize(castedDataType->getRank() - 1));
951  auto step = locB.create<arith::ConstantIndexOp>(1);
952  // TransferWriteOps that operate on tensors return the modified tensor and
953  // require a loop state.
954  auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
955 
956  // Generate for loop.
957  auto result = locB.create<scf::ForOp>(
958  lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
959  [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
960  Type stateType = loopState.empty() ? Type() : loopState[0].getType();
961 
962  auto result = generateInBoundsCheck(
963  b, xferOp, iv, unpackedDim(xferOp),
964  stateType ? TypeRange(stateType) : TypeRange(),
965  /*inBoundsCase=*/
966  [&](OpBuilder &b, Location loc) {
967  // Create new transfer op.
968  OpTy newXfer = Strategy<OpTy>::rewriteOp(
969  b, this->options, xferOp, castedDataBuffer, iv, loopState);
970 
971  // If old transfer op has a mask: Set mask on new transfer op.
972  // Special case: If the mask of the old transfer op is 1D and
973  // the unpacked dim is not a broadcast, no mask is needed on
974  // the new transfer op.
975  if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
976  xferOp.getMaskType().getRank() > 1)) {
977  OpBuilder::InsertionGuard guard(b);
978  b.setInsertionPoint(newXfer); // Insert load before newXfer.
979 
980  SmallVector<Value, 8> loadIndices;
981  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
982  loadIndices, iv);
983  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
984  loadIndices);
985  rewriter.modifyOpInPlace(newXfer, [&]() {
986  newXfer.getMaskMutable().assign(mask);
987  });
988  }
989 
990  return loopState.empty() ? Value() : newXfer->getResult(0);
991  },
992  /*outOfBoundsCase=*/
993  [&](OpBuilder &b, Location /*loc*/) {
994  return Strategy<OpTy>::handleOutOfBoundsDim(
995  b, xferOp, castedDataBuffer, iv, loopState);
996  });
997 
998  maybeYieldValue(b, loc, !loopState.empty(), result);
999  });
1000 
1001  Strategy<OpTy>::cleanup(rewriter, xferOp, result);
1002  return success();
1003  }
1004 };
1005 
1006 /// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
1007 /// and ConstantMaskOp.
1008 template <typename VscaleConstantBuilder>
1009 static FailureOr<SmallVector<OpFoldResult>>
1010 getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
1011  if (!mask)
1012  return SmallVector<OpFoldResult>{};
1013  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
1014  return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
1015  return OpFoldResult(dimSize);
1016  });
1017  }
1018  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
1019  int dimIdx = 0;
1020  VectorType maskType = constantMask.getVectorType();
1021  auto indexType = IndexType::get(mask.getContext());
1022  return llvm::map_to_vector(
1023  constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
1024  // A scalable dim in a constant_mask means vscale x dimSize.
1025  if (maskType.getScalableDims()[dimIdx++])
1026  return OpFoldResult(createVscaleMultiple(dimSize));
1027  return OpFoldResult(IntegerAttr::get(indexType, dimSize));
1028  });
1029  }
1030  return failure();
1031 }
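// Illustrative examples (operands invented):
//   %m = vector.create_mask %a, %b : vector<4x[4]xi1>   --> {%a, %b}
//   %m = vector.constant_mask [2, 4] : vector<4x[4]xi1> --> {2, 4 * vscale}
// A null mask yields an empty vector; any other defining op reports failure.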
1032 
1033 /// Scalable vector lowering of transfer_write(transpose). This lowering only
1034 /// supports rank 2 (scalable) vectors, but can be used in conjunction with
1035 /// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
1036 /// unrolls until the first scalable dimension.
1037 ///
1038 /// Example:
1039 ///
1040 /// BEFORE:
1041 /// ```mlir
1042 /// %transpose = vector.transpose %vec, [1, 0]
1043 /// : vector<4x[4]xf32> to vector<[4]x4xf32>
1044 /// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
1045 /// : vector<[4]x4xf32>, memref<?x?xf32>
1046 /// ```
1047 ///
1048 /// AFTER:
1049 /// ```mlir
1050 /// %c1 = arith.constant 1 : index
1051 /// %c4 = arith.constant 4 : index
1052 /// %c0 = arith.constant 0 : index
1053 /// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
1054 /// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
1055 /// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
1056 /// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
1057 /// %vscale = vector.vscale
1058 /// %c4_vscale = arith.muli %vscale, %c4 : index
1059 /// scf.for %idx = %c0 to %c4_vscale step %c1 {
1060 /// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
1061 /// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
1062 /// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
1063 /// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
1064 /// %slice_i = affine.apply #map(%idx)[%i]
1065 /// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
1066 /// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
1067 /// : vector<4xf32>, memref<?x?xf32>
1068 /// }
1069 /// ```
1070 struct ScalableTransposeTransferWriteConversion
1071  : VectorToSCFPattern<vector::TransferWriteOp> {
1072  using VectorToSCFPattern::VectorToSCFPattern;
1073 
1074  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
1075  PatternRewriter &rewriter) const override {
1076  if (failed(checkLowerTensors(writeOp, rewriter)))
1077  return failure();
1078 
1079  VectorType vectorType = writeOp.getVectorType();
1080 
1081  // Note: By comparing the scalable dims to an ArrayRef of length two this
1082  // implicitly checks that the rank is also two.
1083  ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
1084  if (scalableFlags != ArrayRef<bool>{true, false}) {
1085  return rewriter.notifyMatchFailure(
1086  writeOp, "expected vector of the form vector<[N]xMxty>");
1087  }
1088 
1089  auto permutationMap = writeOp.getPermutationMap();
1090  if (!permutationMap.isIdentity()) {
1091  return rewriter.notifyMatchFailure(
1092  writeOp, "non-identity permutations are unsupported (lower first)");
1093  }
1094 
1095  // Note: This pattern is only lowering the leading dimension (to a loop),
1096  // so we only check if the leading dimension is in bounds. The in-bounds
1097  // attribute for the trailing dimension will be propagated.
1098  if (!writeOp.isDimInBounds(0)) {
1099  return rewriter.notifyMatchFailure(
1100  writeOp, "out-of-bounds dims are unsupported (use masking)");
1101  }
1102 
1103  Value vector = writeOp.getVector();
1104  auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
1105  if (!transposeOp ||
1106  transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
1107  return rewriter.notifyMatchFailure(writeOp, "source not transpose");
1108  }
1109 
1110  auto loc = writeOp.getLoc();
1111  auto createVscaleMultiple =
1112  vector::makeVscaleConstantBuilder(rewriter, loc);
1113 
1114  auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
1115  if (failed(maskDims)) {
1116  return rewriter.notifyMatchFailure(writeOp,
1117  "failed to resolve mask dims");
1118  }
1119 
1120  int64_t fixedDimSize = vectorType.getDimSize(1);
1121  auto fixedDimOffsets = llvm::seq(fixedDimSize);
1122 
1123  // Extract all slices from the source of the transpose.
1124  auto transposeSource = transposeOp.getVector();
1125  SmallVector<Value> transposeSourceSlices =
1126  llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1127  return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
1128  });
1129 
1130  // Loop bounds and step.
1131  auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1132  auto ub =
1133  maskDims->empty()
1134  ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
1135  : vector::getAsValues(rewriter, loc, maskDims->front()).front();
1136  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1137 
1138  // Generate a new mask for the slice.
1139  VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
1140  Value sliceMask = nullptr;
1141  if (!maskDims->empty()) {
1142  sliceMask = rewriter.create<vector::CreateMaskOp>(
1143  loc, sliceType.clone(rewriter.getI1Type()),
1144  ArrayRef<OpFoldResult>(*maskDims).drop_front());
1145  }
1146 
1147  Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{};
1148  ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
1149  auto result = rewriter.create<scf::ForOp>(
1150  loc, lb, ub, step, initLoopArgs,
1151  [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
1152  // Indices for the new transfer op.
1153  SmallVector<Value, 8> xferIndices;
1154  getXferIndices(b, writeOp, iv, xferIndices);
1155 
1156  // Extract a transposed slice from the source vector.
1157  SmallVector<Value> transposeElements =
1158  llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1159  return b.create<vector::ExtractOp>(
1160  loc, transposeSourceSlices[idx], iv);
1161  });
1162  auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
1163  transposeElements);
1164 
1165  // Create the transfer_write for the slice.
1166  Value dest =
1167  loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
1168  auto newWriteOp = b.create<vector::TransferWriteOp>(
1169  loc, sliceVec, dest, xferIndices,
1170  ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
1171  if (sliceMask)
1172  newWriteOp.getMaskMutable().assign(sliceMask);
1173 
1174  // Yield from the loop.
1175  b.create<scf::YieldOp>(loc, loopIterArgs.empty()
1176  ? ValueRange{}
1177  : newWriteOp.getResult());
1178  });
1179 
1180  if (isTensorOp(writeOp))
1181  rewriter.replaceOp(writeOp, result);
1182  else
1183  rewriter.eraseOp(writeOp);
1184 
1185  return success();
1186  }
1187 };
1188 
1189 } // namespace lowering_n_d
1190 
1191 namespace lowering_n_d_unrolled {
1192 
1193 /// If the original transfer op has a mask, compute the mask of the new transfer
1194 /// op (for the current iteration `i`) and assign it.
1195 template <typename OpTy>
1196 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
1197  int64_t i) {
1198  if (!xferOp.getMask())
1199  return;
1200 
1201  if (xferOp.isBroadcastDim(0)) {
1202  // To-be-unpacked dimension is a broadcast, which does not have a
1203  // corresponding mask dimension. Mask attribute remains unchanged.
1204  newXferOp.getMaskMutable().assign(xferOp.getMask());
1205  return;
1206  }
1207 
1208  if (xferOp.getMaskType().getRank() > 1) {
1209  // Unpack one dimension of the mask.
1210  OpBuilder::InsertionGuard guard(b);
1211  b.setInsertionPoint(newXferOp); // Insert load before newXfer.
1212 
1213  llvm::SmallVector<int64_t, 1> indices({i});
1214  Location loc = xferOp.getLoc();
1215  auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
1216  newXferOp.getMaskMutable().assign(newMask);
1217  }
1218 
1219  // If we end up here: The mask of the old transfer op is 1D and the unpacked
1220  // dim is not a broadcast, so no mask is needed on the new transfer op.
1221  // `generateInBoundsCheck` will have evaluated the mask already.
1222 }
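// E.g. (illustrative): for a 2-D mask %mask : vector<5x4xi1> and iteration
// i = 2, the new transfer op receives
//   vector.extract %mask[2] : vector<4xi1> from vector<5x4xi1>
// as its mask.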
1223 
1224 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
1225 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
1226 /// memref buffer is allocated and the SCF loop is fully unrolled.
1227 ///
1228 /// E.g.:
1229 /// ```
1231 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
1232 /// : memref<?x?x?xf32>, vector<5x4xf32>
1233 /// ```
1234 /// is rewritten to IR such as (simplified):
1235 /// ```
1236 /// %v_init = splat %padding : vector<5x4xf32>
1237 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
1238 /// : memref<?x?x?xf32>, vector<4xf32>
1239 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
1240 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
1241 /// : memref<?x?x?xf32>, vector<4xf32>
1242 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
1243 /// ...
1244 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
1245 /// : memref<?x?x?xf32>, vector<4xf32>
1246 /// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
1247 /// ```
1248 ///
1249 /// Note: As an optimization, if the result of the original TransferReadOp
1250 /// was directly inserted into another vector, no new %v_init vector is created.
1251 /// Instead, the new TransferReadOp results are inserted into that vector.
1252 struct UnrollTransferReadConversion
1253  : public VectorToSCFPattern<TransferReadOp> {
1254  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
1255 
1256  void initialize() {
1257  // This pattern recursively unpacks one dimension at a time. The recursion
1258  // is bounded because the rank is strictly decreasing.
1259  setHasBoundedRewriteRecursion();
1260  }
1261 
1262  /// Get or build the vector into which the newly created TransferReadOp
1263  /// results are inserted.
1264  Value buildResultVector(PatternRewriter &rewriter,
1265  TransferReadOp xferOp) const {
1266  if (auto insertOp = getInsertOp(xferOp))
1267  return insertOp.getDest();
1268  Location loc = xferOp.getLoc();
1269  return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1270  xferOp.getPadding());
1271  }
1272 
1273  /// If the result of the TransferReadOp has exactly one user, which is a
1274  /// vector::InsertOp, return that operation.
1275  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
1276  if (xferOp->hasOneUse()) {
1277  Operation *xferOpUser = *xferOp->getUsers().begin();
1278  if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
1279  return insertOp;
1280  }
1281 
1282  return vector::InsertOp();
1283  }
1284 
1285  /// If the result of the TransferReadOp has exactly one user, which is a
1286  /// vector::InsertOp, return that operation's indices.
1287  void getInsertionIndices(TransferReadOp xferOp,
1288  SmallVectorImpl<OpFoldResult> &indices) const {
1289  if (auto insertOp = getInsertOp(xferOp)) {
1290  auto pos = insertOp.getMixedPosition();
1291  indices.append(pos.begin(), pos.end());
1292  }
1293  }
1294 
1295  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1296  /// accesses, and broadcasts and transposes in permutation maps.
1297  LogicalResult matchAndRewrite(TransferReadOp xferOp,
1298  PatternRewriter &rewriter) const override {
1299  if (xferOp.getVectorType().getRank() <= options.targetRank)
1300  return rewriter.notifyMatchFailure(
1301  xferOp, "vector rank is less than or equal to the target rank");
1302  if (failed(checkLowerTensors(xferOp, rewriter)))
1303  return failure();
1304  if (xferOp.getVectorType().getElementType() !=
1305  xferOp.getShapedType().getElementType())
1306  return rewriter.notifyMatchFailure(
1307  xferOp, "not yet supported: element type mismatch");
1308  auto xferVecType = xferOp.getVectorType();
1309  if (xferVecType.getScalableDims()[0]) {
1310  return rewriter.notifyMatchFailure(
1311  xferOp, "scalable dimensions cannot be unrolled at compile time");
1312  }
1313 
1314  auto insertOp = getInsertOp(xferOp);
1315  auto vec = buildResultVector(rewriter, xferOp);
1316  auto vecType = dyn_cast<VectorType>(vec.getType());
1317 
1318  VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
1319 
1320  int64_t dimSize = xferVecType.getShape()[0];
1321 
1322  // Generate fully unrolled loop of transfer ops.
1323  Location loc = xferOp.getLoc();
1324  for (int64_t i = 0; i < dimSize; ++i) {
1325  Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1326 
1327  vec = generateInBoundsCheck(
1328  rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
1329  /*inBoundsCase=*/
1330  [&](OpBuilder &b, Location loc) {
1331  // Indices for the new transfer op.
1332  SmallVector<Value, 8> xferIndices;
1333  getXferIndices(b, xferOp, iv, xferIndices);
1334 
1335  // Indices for the new vector.insert op.
1336  SmallVector<OpFoldResult, 8> insertionIndices;
1337  getInsertionIndices(xferOp, insertionIndices);
1338  insertionIndices.push_back(rewriter.getIndexAttr(i));
1339 
1340  auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1341  auto newXferOp = b.create<vector::TransferReadOp>(
1342  loc, newXferVecType, xferOp.getBase(), xferIndices,
1343  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1344  xferOp.getPadding(), Value(), inBoundsAttr);
1345  maybeAssignMask(b, xferOp, newXferOp, i);
1346  return b.create<vector::InsertOp>(loc, newXferOp, vec,
1347  insertionIndices);
1348  },
1349  /*outOfBoundsCase=*/
1350  [&](OpBuilder &b, Location loc) {
1351  // Loop through original (unmodified) vector.
1352  return vec;
1353  });
1354  }
1355 
1356  if (insertOp) {
1357  // Rewrite single user of the old TransferReadOp, which was an InsertOp.
1358  rewriter.replaceOp(insertOp, vec);
1359  rewriter.eraseOp(xferOp);
1360  } else {
1361  rewriter.replaceOp(xferOp, vec);
1362  }
1363 
1364  return success();
1365  }
1366 };
1367 
1368 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
1369 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
1370 /// memref buffer is allocated and the SCF loop is fully unrolled.
1371 ///
1372 /// ```
1373 /// E.g.:
1374 /// ```
1375 /// vector.transfer_write %vec, %A[%a, %b, %c]
1376 /// : vector<5x4xf32>, memref<?x?x?xf32>
1377 /// ```
1378 /// is rewritten to IR such as (simplified):
1379 /// ```
1380 /// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32>
1381 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
1382 /// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32>
1383 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
1384 /// ...
1385 /// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32>
1386 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
1387 /// ```
1388 ///
1389 /// Note: As an optimization, if the vector of the original TransferWriteOp
1390 /// was directly extracted from another vector via an ExtractOp `a`, extract
1391 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
1392 /// doing so, `a` may become dead, and the number of ExtractOps generated during
1393 /// recursive application of this pattern will be minimal.
1394 struct UnrollTransferWriteConversion
1395  : public VectorToSCFPattern<TransferWriteOp> {
1396  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
1397 
1398  void initialize() {
1399  // This pattern recursively unpacks one dimension at a time. The recursion
1400  // is bounded because the rank is strictly decreasing.
1401  setHasBoundedRewriteRecursion();
1402  }
1403 
1404  /// Return the vector from which newly generated ExtractOps will extract.
1405  Value getDataVector(TransferWriteOp xferOp) const {
1406  if (auto extractOp = getExtractOp(xferOp))
1407  return extractOp.getVector();
1408  return xferOp.getVector();
1409  }
1410 
1411  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
1412  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
1413  if (auto *op = xferOp.getVector().getDefiningOp())
1414  return dyn_cast<vector::ExtractOp>(op);
1415  return vector::ExtractOp();
1416  }
1417 
1418  /// If the input of the given TransferWriteOp is an ExtractOp, return its
1419  /// indices.
1420  void getExtractionIndices(TransferWriteOp xferOp,
1421  SmallVectorImpl<OpFoldResult> &indices) const {
1422  if (auto extractOp = getExtractOp(xferOp)) {
1423  auto pos = extractOp.getMixedPosition();
1424  indices.append(pos.begin(), pos.end());
1425  }
1426  }
1427 
1428  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1429  /// accesses, and broadcasts and transposes in permutation maps.
1430  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1431  PatternRewriter &rewriter) const override {
1432  VectorType inputVectorTy = xferOp.getVectorType();
1433 
1434  if (inputVectorTy.getRank() <= options.targetRank)
1435  return failure();
1436 
1437  if (failed(checkLowerTensors(xferOp, rewriter)))
1438  return failure();
1439  // Transfer ops that modify the element type are not supported atm.
1440  if (inputVectorTy.getElementType() !=
1441  xferOp.getShapedType().getElementType())
1442  return failure();
1443 
1444  auto vec = getDataVector(xferOp);
1445  if (inputVectorTy.getScalableDims()[0]) {
1446  // Cannot unroll a scalable dimension at compile time.
1447  return failure();
1448  }
1449 
1450  int64_t dimSize = inputVectorTy.getShape()[0];
1451  Value source = xferOp.getBase(); // memref or tensor to be written to.
1452  auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1453 
1454  // Generate fully unrolled loop of transfer ops.
1455  Location loc = xferOp.getLoc();
1456  for (int64_t i = 0; i < dimSize; ++i) {
1457  Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1458 
1459  auto updatedSource = generateInBoundsCheck(
1460  rewriter, xferOp, iv, unpackedDim(xferOp),
1461  isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1462  /*inBoundsCase=*/
1463  [&](OpBuilder &b, Location loc) {
1464  // Indices for the new transfer op.
1465  SmallVector<Value, 8> xferIndices;
1466  getXferIndices(b, xferOp, iv, xferIndices);
1467 
1468  // Indices for the new vector.extract op.
1469  SmallVector<OpFoldResult, 8> extractionIndices;
1470  getExtractionIndices(xferOp, extractionIndices);
1471  extractionIndices.push_back(b.getI64IntegerAttr(i));
1472 
1473  auto extracted =
1474  b.create<vector::ExtractOp>(loc, vec, extractionIndices);
1475  auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1476  Value xferVec;
1477  if (inputVectorTy.getRank() == 1) {
 1478  // When target-rank=0, unrolling would cause the vector operand
 1479  // of `transfer_write` to become a scalar. We solve this by
 1480  // broadcasting the scalar to a 0-D vector (see the sketch after this pattern).
1481  xferVec = b.create<vector::BroadcastOp>(
1482  loc, VectorType::get({}, extracted.getType()), extracted);
1483  } else {
1484  xferVec = extracted;
1485  }
1486  auto newXferOp = b.create<vector::TransferWriteOp>(
1487  loc, sourceType, xferVec, source, xferIndices,
1488  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
1489  inBoundsAttr);
1490 
1491  maybeAssignMask(b, xferOp, newXferOp, i);
1492 
1493  return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1494  },
1495  /*outOfBoundsCase=*/
1496  [&](OpBuilder &b, Location loc) {
1497  return isTensorOp(xferOp) ? source : Value();
1498  });
1499 
1500  if (isTensorOp(xferOp))
1501  source = updatedSource;
1502  }
1503 
1504  if (isTensorOp(xferOp))
1505  rewriter.replaceOp(xferOp, source);
1506  else
1507  rewriter.eraseOp(xferOp);
1508 
1509  return success();
1510  }
1511 };
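// Illustrative sketch of the rank-1 case handled above (hypothetical, not
// taken from the tests): with targetRank = 0, a write such as
//   vector.transfer_write %v, %A[%i] : vector<4xf32>, memref<?xf32>
// unrolls into scalar extracts, each broadcast back to a 0-D vector:
//   %e0 = vector.extract %v[0] : f32 from vector<4xf32>
//   %b0 = vector.broadcast %e0 : f32 to vector<f32>
//   vector.transfer_write %b0, %A[%i] : vector<f32>, memref<?xf32>
//   ...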
1512 
1513 } // namespace lowering_n_d_unrolled
1514 
1515 namespace lowering_1_d {
1516 
1517 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
1518 /// part of TransferOp1dConversion. Return the memref dimension on which
1519 /// the transfer is operating. A return value of std::nullopt indicates a
1520 /// broadcast.
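/// For example (illustrative): given transfer indices [%a, %b] and permutation
/// map (d0, d1) -> (d0), `memrefIndices` becomes [%a + iv, %b] and the
/// returned dimension is 0.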
1521 template <typename OpTy>
1522 static std::optional<int64_t>
1523 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1524  SmallVector<Value, 8> &memrefIndices) {
1525  auto indices = xferOp.getIndices();
1526  auto map = xferOp.getPermutationMap();
1527  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
1528 
1529  memrefIndices.append(indices.begin(), indices.end());
1530  assert(map.getNumResults() == 1 &&
1531  "Expected 1 permutation map result for 1D transfer");
1532  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
1533  Location loc = xferOp.getLoc();
1534  auto dim = expr.getPosition();
1535  AffineExpr d0, d1;
1536  bindDims(xferOp.getContext(), d0, d1);
1537  Value offset = memrefIndices[dim];
1538  memrefIndices[dim] =
1539  affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1540  return dim;
1541  }
1542 
1543  assert(xferOp.isBroadcastDim(0) &&
1544  "Expected AffineDimExpr or AffineConstantExpr");
1545  return std::nullopt;
1546 }
1547 
1548 /// Codegen strategy for TransferOp1dConversion, depending on the
1549 /// operation.
1550 template <typename OpTy>
1551 struct Strategy1d;
1552 
1553 /// Codegen strategy for TransferReadOp.
1554 template <>
1555 struct Strategy1d<TransferReadOp> {
1556  static void generateForLoopBody(OpBuilder &b, Location loc,
1557  TransferReadOp xferOp, Value iv,
1558  ValueRange loopState) {
1559  SmallVector<Value, 8> indices;
1560  auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1561  auto vec = loopState[0];
1562 
 1563  // In case of out-of-bounds access, leave `vec` as is (it was initialized
 1564  // with the padding value).
1565  auto nextVec = generateInBoundsCheck(
1566  b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1567  /*inBoundsCase=*/
1568  [&](OpBuilder &b, Location loc) {
1569  Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices);
1570  return b.create<vector::InsertElementOp>(loc, val, vec, iv);
1571  },
1572  /*outOfBoundsCase=*/
1573  [&](OpBuilder & /*b*/, Location loc) { return vec; });
1574  b.create<scf::YieldOp>(loc, nextVec);
1575  }
1576 
1577  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
 1578  // Initialize the vector with the padding value.
1579  Location loc = xferOp.getLoc();
1580  return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1581  xferOp.getPadding());
1582  }
1583 };
1584 
1585 /// Codegen strategy for TransferWriteOp.
1586 template <>
1587 struct Strategy1d<TransferWriteOp> {
1588  static void generateForLoopBody(OpBuilder &b, Location loc,
1589  TransferWriteOp xferOp, Value iv,
1590  ValueRange /*loopState*/) {
1591  SmallVector<Value, 8> indices;
1592  auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1593 
1594  // Nothing to do in case of out-of-bounds access.
1595  generateInBoundsCheck(
1596  b, xferOp, iv, dim,
1597  /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1598  auto val =
1599  b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
1600  b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices);
1601  });
1602  b.create<scf::YieldOp>(loc);
1603  }
1604 
1605  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1606  return Value();
1607  }
1608 };
1609 
1610 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1611 /// necessary in cases where a 1D vector transfer op cannot be lowered into
 1612 /// vector loads/stores due to non-unit strides or broadcasts:
1613 ///
1614 /// * Transfer dimension is not the last memref dimension
1615 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1616 /// * Memref has a layout map with non-unit stride on the last dimension
1617 ///
1618 /// This pattern generates IR as follows:
1619 ///
1620 /// 1. Generate a for loop iterating over each vector element.
 1621 /// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
1622 /// depending on OpTy.
1623 ///
1624 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1625 /// can be generated instead of TransferOp1dConversion. Add such a pattern
1626 /// to ConvertVectorToLLVM.
1627 ///
1628 /// E.g.:
1629 /// ```
1630 /// vector.transfer_write %vec, %A[%a, %b]
1631 /// {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1632 /// : vector<9xf32>, memref<?x?xf32>
1633 /// ```
 1634 /// is rewritten to approximately the following pseudo-IR:
1635 /// ```
1636 /// for i = 0 to 9 {
1637 /// %t = vector.extractelement %vec[i] : vector<9xf32>
1638 /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
1639 /// }
1640 /// ```
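/// For a TransferReadOp, the generated loop is analogous (illustrative,
/// simplified pseudo-IR; the in-bounds check is omitted):
/// ```
/// %init = vector.splat %padding : vector<9xf32>
/// %vec = for i = 0 to 9 iter_args(%v = %init) {
///   %t = memref.load %arg0[%a + i, %b] : memref<?x?xf32>
///   %v2 = vector.insertelement %t, %v[i] : vector<9xf32>
///   yield %v2
/// }
/// ```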
1641 template <typename OpTy>
1642 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1643  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1644 
1645  LogicalResult matchAndRewrite(OpTy xferOp,
1646  PatternRewriter &rewriter) const override {
1647  // TODO: support 0-d corner case.
1648  if (xferOp.getTransferRank() == 0)
1649  return failure();
1650  auto map = xferOp.getPermutationMap();
1651  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
1652 
1653  if (!memRefType)
1654  return failure();
1655  if (xferOp.getVectorType().getRank() != 1)
1656  return failure();
1657  if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
1658  return failure(); // Handled by ConvertVectorToLLVM
1659 
1660  // Loop bounds, step, state...
1661  Location loc = xferOp.getLoc();
1662  auto vecType = xferOp.getVectorType();
1663  auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1664  Value ub =
1665  rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
1666  if (vecType.isScalable()) {
1667  Value vscale =
1668  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
1669  ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
1670  }
1671  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1672  auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1673 
1674  // Generate for loop.
1675  rewriter.replaceOpWithNewOp<scf::ForOp>(
1676  xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1677  [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
1678  Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1679  });
1680 
1681  return success();
1682  }
1683 };
1684 
1685 } // namespace lowering_1_d
1686 } // namespace
1687 
 1688 void mlir::populateVectorToSCFConversionPatterns(
 1689     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
 1690  if (options.unroll) {
1691  patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1692  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1693  patterns.getContext(), options);
1694  } else {
1695  patterns.add<lowering_n_d::PrepareTransferReadConversion,
1696  lowering_n_d::PrepareTransferWriteConversion,
1697  lowering_n_d::TransferOpConversion<TransferReadOp>,
1698  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1699  patterns.getContext(), options);
1700  }
1701  if (options.lowerScalable) {
1702  patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
1703  patterns.getContext(), options);
1704  }
1705  if (options.targetRank == 1) {
1706  patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1707  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1708  patterns.getContext(), options);
1709  }
1710  patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
1711  options);
1712 }
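// Illustrative usage from a custom pass or pipeline setup (a sketch; `ctx` and
// `op` are assumed to be a valid MLIRContext* and the root operation, and are
// not names from this file):
//
//   VectorTransferToSCFOptions opts;
//   opts.unroll = true;   // use the fully unrolled lowering
//   opts.targetRank = 1;  // lower until transfers are rank 1 or lower
//   RewritePatternSet patterns(ctx);
//   populateVectorToSCFConversionPatterns(patterns, opts);
//   (void)applyPatternsGreedily(op, std::move(patterns));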
1713 
1714 namespace {
1715 
1716 struct ConvertVectorToSCFPass
1717  : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1718  ConvertVectorToSCFPass() = default;
1719  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1720  this->fullUnroll = options.unroll;
1721  this->targetRank = options.targetRank;
1722  this->lowerTensors = options.lowerTensors;
1723  this->lowerScalable = options.lowerScalable;
1724  }
1725 
1726  void runOnOperation() override {
 1727  VectorTransferToSCFOptions options;
 1728  options.unroll = fullUnroll;
1729  options.targetRank = targetRank;
1730  options.lowerTensors = lowerTensors;
1731  options.lowerScalable = lowerScalable;
1732 
1733  // Lower permutation maps first.
1734  RewritePatternSet lowerTransferPatterns(&getContext());
 1735  mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
 1736      lowerTransferPatterns);
1737  (void)applyPatternsGreedily(getOperation(),
1738  std::move(lowerTransferPatterns));
1739 
 1740  RewritePatternSet patterns(&getContext());
 1741  populateVectorToSCFConversionPatterns(patterns, options);
 1742  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
1743  }
1744 };
1745 
1746 } // namespace
1747 
1748 std::unique_ptr<Pass>
 1749 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
 1750  return std::make_unique<ConvertVectorToSCFPass>(options);
1751 }
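// Illustrative pipeline registration (a sketch; `pm` is a hypothetical
// mlir::PassManager, not defined in this file):
//
//   VectorTransferToSCFOptions options;
//   options.targetRank = 1;
//   pm.addPass(createConvertVectorToSCFPass(options));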