1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <numeric>
14 #include <optional>
15 
17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
18 
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/Builders.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29 
30 namespace mlir {
31 #define GEN_PASS_DEF_CONVERTVECTORTOSCF
32 #include "mlir/Conversion/Passes.h.inc"
33 } // namespace mlir
34 
35 using namespace mlir;
36 using vector::TransferReadOp;
37 using vector::TransferWriteOp;
38 
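// For orientation, a minimal usage sketch (not part of this file; the option
// values and the `ctx`/`op` names are illustrative only). The public entry
// point is declared in mlir/Conversion/VectorToSCF/VectorToSCF.h:
//
//   RewritePatternSet patterns(ctx);
//   VectorTransferToSCFOptions options;
//   options.targetRank = 1;        // progressively lower to 1-D transfers
//   options.lowerTensors = false;  // skip transfers on tensor sources
//   populateVectorToSCFConversionPatterns(patterns, options);
//   // Typically driven by a greedy pattern rewrite, e.g.:
//   // (void)applyPatternsGreedily(op, std::move(patterns));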
39 namespace {
40 
41 /// Attribute name used for labeling transfer ops during progressive lowering.
42 static const char kPassLabel[] = "__vector_to_scf_lowering__";
43 
44 /// Return true if this transfer op operates on a source tensor.
45 static bool isTensorOp(VectorTransferOpInterface xferOp) {
46  if (isa<RankedTensorType>(xferOp.getShapedType())) {
47  if (isa<vector::TransferWriteOp>(xferOp)) {
48  // TransferWriteOps on tensors have a result.
49  assert(xferOp->getNumResults() > 0);
50  }
51  return true;
52  }
53  return false;
54 }
55 
56 /// Patterns that inherit from this struct have access to
57 /// VectorTransferToSCFOptions.
58 template <typename OpTy>
59 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
60  explicit VectorToSCFPattern(MLIRContext *context,
61  VectorTransferToSCFOptions opt)
62  : OpRewritePattern<OpTy>(context), options(opt) {}
63 
64  LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
65  PatternRewriter &rewriter) const {
66  if (isTensorOp(xferOp) && !options.lowerTensors) {
67  return rewriter.notifyMatchFailure(
68  xferOp, "lowering tensor transfers is disabled");
69  }
70  return success();
71  }
72 
73  VectorTransferToSCFOptions options;
74 };
75 
76 /// Given a vector transfer op, calculate which dimension of the `source`
77 /// memref should be unpacked in the next application of TransferOpConversion.
78 /// A return value of std::nullopt indicates a broadcast.
79 template <typename OpTy>
80 static std::optional<int64_t> unpackedDim(OpTy xferOp) {
81  // TODO: support 0-d corner case.
82  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
83  auto map = xferOp.getPermutationMap();
84  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
85  return expr.getPosition();
86  }
87  assert(xferOp.isBroadcastDim(0) &&
88  "Expected AffineDimExpr or AffineConstantExpr");
89  return std::nullopt;
90 }
91 
92 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
93 /// map is identical to the current permutation map, but the first result is
94 /// omitted.
95 template <typename OpTy>
96 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
97  // TODO: support 0-d corner case.
98  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
99  auto map = xferOp.getPermutationMap();
100  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
101  b.getContext());
102 }
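// Worked example for the two helpers above (hypothetical op, for illustration
// only):
//   permutation map of xferOp        : (d0, d1, d2) -> (d2, d1)
//   unpackedDim(xferOp)              : 2
//   unpackedPermutationMap(b, xferOp): (d0, d1, d2) -> (d1)
// If the first map result were the constant 0 (a broadcast), unpackedDim
// would return std::nullopt instead, while the new map would still simply
// drop the first result.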
103 
104 /// Calculate the indices for the new vector transfer op.
105 ///
106 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
107 /// --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
108 /// ^^^^^^
109 /// `iv` is the iteration variable of the (new) surrounding loop.
110 template <typename OpTy>
111 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
112  SmallVector<Value, 8> &indices) {
113  typename OpTy::Adaptor adaptor(xferOp);
114  // Corresponding memref dim of the vector dim that is unpacked.
115  auto dim = unpackedDim(xferOp);
116  auto prevIndices = adaptor.getIndices();
117  indices.append(prevIndices.begin(), prevIndices.end());
118 
119  Location loc = xferOp.getLoc();
120  bool isBroadcast = !dim.has_value();
121  if (!isBroadcast) {
122  AffineExpr d0, d1;
123  bindDims(xferOp.getContext(), d0, d1);
124  Value offset = adaptor.getIndices()[*dim];
125  indices[*dim] =
126  affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
127  }
128 }
129 
130 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
131  Value value) {
132  if (hasRetVal) {
133  assert(value && "Expected non-empty value");
134  scf::YieldOp::create(b, loc, value);
135  } else {
136  scf::YieldOp::create(b, loc);
137  }
138 }
139 
140 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
141 /// is set to true. No such check is generated under the following circumstances:
142 /// * xferOp does not have a mask.
143 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
144 /// computed and attached to the new transfer op in the pattern.)
145 /// * The to-be-unpacked dim of xferOp is a broadcast.
146 template <typename OpTy>
147 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
148  if (!xferOp.getMask())
149  return Value();
150  if (xferOp.getMaskType().getRank() != 1)
151  return Value();
152  if (xferOp.isBroadcastDim(0))
153  return Value();
154 
155  Location loc = xferOp.getLoc();
156  return vector::ExtractOp::create(b, loc, xferOp.getMask(), iv);
157 }
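// Example of the check generated above (illustrative): for a 1-D mask
// %mask : vector<5xi1> and induction variable %iv, the emitted op is
// ```
// %mask_bit = vector.extract %mask[%iv] : i1 from vector<5xi1>
// ```
// and the resulting i1 is AND-ed into the in-bounds condition by
// generateInBoundsCheck below.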
158 
159 /// Helper function for TransferOpConversion and TransferOp1dConversion.
160 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
161 /// specified dimension `dim` with the loop iteration variable `iv`.
162 /// E.g., when unpacking dimension 0 from:
163 /// ```
164 /// %vec = vector.transfer_read %A[%a, %b] %cst
165 /// : vector<5x4xf32>, memref<?x?xf32>
166 /// ```
167 /// An if check similar to this will be generated inside the loop:
168 /// ```
169 /// %d = memref.dim %A, %c0 : memref<?x?xf32>
170 /// if (%a + iv < %d) {
171 /// (in-bounds case)
172 /// } else {
173 /// (out-of-bounds case)
174 /// }
175 /// ```
176 ///
177 /// If the transfer is 1D and has a mask, this function generates a more complex
178 /// check that also accounts for potentially masked-out elements.
179 ///
180 /// This function variant returns the value returned by `inBoundsCase` or
181 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
182 /// `resultTypes`.
183 template <typename OpTy>
184 static Value generateInBoundsCheck(
185  OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
186  TypeRange resultTypes,
187  function_ref<Value(OpBuilder &, Location)> inBoundsCase,
188  function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
189  bool hasRetVal = !resultTypes.empty();
190  Value cond; // Condition to be built...
191 
192  // Condition check 1: Access in-bounds?
193  bool isBroadcast = !dim; // No in-bounds check for broadcasts.
194  Location loc = xferOp.getLoc();
195  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
196  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
197  Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.getBase(), *dim);
198  AffineExpr d0, d1;
199  bindDims(xferOp.getContext(), d0, d1);
200  Value base = xferOp.getIndices()[*dim];
201  Value memrefIdx =
202  affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
203  cond = arith::CmpIOp::create(lb, arith::CmpIPredicate::sgt, memrefDim,
204  memrefIdx);
205  }
206 
207  // Condition check 2: Masked in?
208  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
209  if (cond)
210  cond = arith::AndIOp::create(lb, cond, maskCond);
211  else
212  cond = maskCond;
213  }
214 
215  // If the condition is non-empty, generate an SCF::IfOp.
216  if (cond) {
217  auto check = scf::IfOp::create(
218  lb, cond,
219  /*thenBuilder=*/
220  [&](OpBuilder &b, Location loc) {
221  maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
222  },
223  /*elseBuilder=*/
224  [&](OpBuilder &b, Location loc) {
225  if (outOfBoundsCase) {
226  maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
227  } else {
228  scf::YieldOp::create(b, loc);
229  }
230  });
231 
232  return hasRetVal ? check.getResult(0) : Value();
233  }
234 
235  // Condition is empty, no need for an SCF::IfOp.
236  return inBoundsCase(b, loc);
237 }
238 
239 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
240 /// a return value. Consequently, this function does not have a return value.
241 template <typename OpTy>
242 static void generateInBoundsCheck(
243  OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
244  function_ref<void(OpBuilder &, Location)> inBoundsCase,
245  function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
246  generateInBoundsCheck(
247  b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
248  /*inBoundsCase=*/
249  [&](OpBuilder &b, Location loc) {
250  inBoundsCase(b, loc);
251  return Value();
252  },
253  /*outOfBoundsCase=*/
254  [&](OpBuilder &b, Location loc) {
255  if (outOfBoundsCase)
256  outOfBoundsCase(b, loc);
257  return Value();
258  });
259 }
260 
261 /// Given an ArrayAttr, return a copy where the first element is dropped.
262 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
263  if (!attr)
264  return attr;
265  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
266 }
267 
268 /// Add the pass label to a vector transfer op if its rank is still greater
269 /// than the target rank.
270 template <typename OpTy>
271 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
272  unsigned targetRank) {
273  if (newXferOp.getVectorType().getRank() > targetRank)
274  newXferOp->setAttr(kPassLabel, b.getUnitAttr());
275 }
276 
277 namespace lowering_n_d {
278 
279 /// Helper data structure for data and mask buffers.
280 struct BufferAllocs {
281  Value dataBuffer;
282  Value maskBuffer;
283 };
284 
285 // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
286 static Operation *getAutomaticAllocationScope(Operation *op) {
287  Operation *scope =
288  op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
289  assert(scope && "Expected op to be inside automatic allocation scope");
290  return scope;
291 }
292 
293 /// Allocate temporary buffers for data (vector) and mask (if present).
294 template <typename OpTy>
295 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
296  Location loc = xferOp.getLoc();
297  OpBuilder::InsertionGuard guard(b);
298  Operation *scope = getAutomaticAllocationScope(xferOp);
299  assert(scope->getNumRegions() == 1 &&
300  "AutomaticAllocationScope with >1 regions");
301  b.setInsertionPointToStart(&scope->getRegion(0).front());
302 
303  BufferAllocs result;
304  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
305  result.dataBuffer = memref::AllocaOp::create(b, loc, bufferType);
306 
307  if (xferOp.getMask()) {
308  auto maskType = MemRefType::get({}, xferOp.getMask().getType());
309  auto maskBuffer = memref::AllocaOp::create(b, loc, maskType);
310  b.setInsertionPoint(xferOp);
311  memref::StoreOp::create(b, loc, xferOp.getMask(), maskBuffer);
312  result.maskBuffer =
313  memref::LoadOp::create(b, loc, maskBuffer, ValueRange());
314  }
315 
316  return result;
317 }
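// Sketch of the buffers created for a masked transfer (illustrative types,
// not taken from a real test case):
// ```
// // At the start of the enclosing AutomaticAllocationScope:
// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
// %mask_buf = memref.alloca() : memref<vector<5x4xi1>>
// // Immediately before the transfer op:
// memref.store %mask, %mask_buf[] : memref<vector<5x4xi1>>
// %mask_val = memref.load %mask_buf[] : memref<vector<5x4xi1>>
// ```
// result.dataBuffer is the data alloca; result.maskBuffer is the re-loaded
// mask value, whose defining memref.load is how getMaskBuffer() recovers
// %mask_buf later.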
318 
319 /// Given a MemRefType with VectorType element type, unpack one dimension from
320 /// the VectorType into the MemRefType.
321 ///
322 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
323 static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
324  auto vectorType = dyn_cast<VectorType>(type.getElementType());
325  // Vectors with leading scalable dims are not supported.
326  // It may be possible to support these in future by using dynamic memref dims.
327  if (vectorType.getScalableDims().front())
328  return failure();
329  auto memrefShape = type.getShape();
330  SmallVector<int64_t, 8> newMemrefShape;
331  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
332  newMemrefShape.push_back(vectorType.getDimSize(0));
333  return MemRefType::get(newMemrefShape,
334  VectorType::Builder(vectorType).dropDim(0));
335 }
336 
337 /// Given a transfer op, find the memref from which the mask is loaded. This
338 /// is similar to Strategy<TransferWriteOp>::getBuffer.
339 template <typename OpTy>
340 static Value getMaskBuffer(OpTy xferOp) {
341  assert(xferOp.getMask() && "Expected that transfer op has mask");
342  auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
343  assert(loadOp && "Expected transfer op mask produced by LoadOp");
344  return loadOp.getMemRef();
345 }
346 
347 /// Codegen strategy, depending on the operation.
348 template <typename OpTy>
349 struct Strategy;
350 
351 /// Codegen strategy for vector TransferReadOp.
352 template <>
353 struct Strategy<TransferReadOp> {
354  /// Find the StoreOp that is used for writing the current TransferReadOp's
355  /// result to the temporary buffer allocation.
356  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
357  assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
358  auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
359  assert(storeOp && "Expected TransferReadOp result used by StoreOp");
360  return storeOp;
361  }
362 
363  /// Find the temporary buffer allocation. All labeled TransferReadOps are
364  /// used like this, where %buf is either the buffer allocation or a type cast
365  /// of the buffer allocation:
366  /// ```
367  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
368  /// memref.store %vec, %buf[...] ...
369  /// ```
370  static Value getBuffer(TransferReadOp xferOp) {
371  return getStoreOp(xferOp).getMemRef();
372  }
373 
374  /// Retrieve the indices of the current StoreOp that stores into the buffer.
375  static void getBufferIndices(TransferReadOp xferOp,
376  SmallVector<Value, 8> &indices) {
377  auto storeOp = getStoreOp(xferOp);
378  auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
379  indices.append(prevIndices.begin(), prevIndices.end());
380  }
381 
382  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
383  /// accesses on the to-be-unpacked dimension.
384  ///
385  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
386  /// variable `iv`.
387  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
388  ///
389  /// E.g.:
390  /// ```
391  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
392  /// : memref<?x?x?xf32>, vector<4x3xf32>
393  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
394  /// ```
395  /// Is rewritten to:
396  /// ```
397  /// %casted = vector.type_cast %buf
398  /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
399  /// for %j = 0 to 4 {
400  /// %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
401  /// : memref<?x?x?xf32>, vector<3xf32>
402  /// memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
403  /// }
404  /// ```
405  ///
406  /// Note: The loop and type cast are generated in TransferOpConversion.
407  /// The original TransferReadOp and store op are deleted in `cleanup`.
408  /// Note: The `mask` operand is set in TransferOpConversion.
409  static TransferReadOp rewriteOp(OpBuilder &b,
410  VectorTransferToSCFOptions options,
411  TransferReadOp xferOp, Value buffer, Value iv,
412  ValueRange /*loopState*/) {
413  SmallVector<Value, 8> storeIndices;
414  getBufferIndices(xferOp, storeIndices);
415  storeIndices.push_back(iv);
416 
417  SmallVector<Value, 8> xferIndices;
418  getXferIndices(b, xferOp, iv, xferIndices);
419 
420  Location loc = xferOp.getLoc();
421  auto bufferType = dyn_cast<ShapedType>(buffer.getType());
422  auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
423  auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
424  auto newXferOp = vector::TransferReadOp::create(
425  b, loc, vecType, xferOp.getBase(), xferIndices,
426  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
427  xferOp.getPadding(), Value(), inBoundsAttr);
428 
429  maybeApplyPassLabel(b, newXferOp, options.targetRank);
430 
431  memref::StoreOp::create(b, loc, newXferOp.getVector(), buffer,
432  storeIndices);
433  return newXferOp;
434  }
435 
436  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
437  /// padding value to the temporary buffer.
438  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
439  Value buffer, Value iv,
440  ValueRange /*loopState*/) {
441  SmallVector<Value, 8> storeIndices;
442  getBufferIndices(xferOp, storeIndices);
443  storeIndices.push_back(iv);
444 
445  Location loc = xferOp.getLoc();
446  auto bufferType = dyn_cast<ShapedType>(buffer.getType());
447  auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
448  auto vec =
449  vector::BroadcastOp::create(b, loc, vecType, xferOp.getPadding());
450  memref::StoreOp::create(b, loc, vec, buffer, storeIndices);
451 
452  return Value();
453  }
454 
455  /// Cleanup after rewriting the op.
456  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
457  scf::ForOp /*forOp*/) {
458  rewriter.eraseOp(getStoreOp(xferOp));
459  rewriter.eraseOp(xferOp);
460  }
461 
462  /// Return the initial loop state for the generated scf.for loop.
463  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
464 };
465 
466 /// Codegen strategy for vector TransferWriteOp.
467 template <>
468 struct Strategy<TransferWriteOp> {
469  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
470  /// used like this, where %buf is either the buffer allocation or a type cast
471  /// of the buffer allocation:
472  /// ```
473  /// %vec = memref.load %buf[...] ...
474  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
475  /// ```
476  static Value getBuffer(TransferWriteOp xferOp) {
477  auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
478  assert(loadOp && "Expected transfer op vector produced by LoadOp");
479  return loadOp.getMemRef();
480  }
481 
482  /// Retrieve the indices of the current LoadOp that loads from the buffer.
483  static void getBufferIndices(TransferWriteOp xferOp,
484  SmallVector<Value, 8> &indices) {
485  auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
486  auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
487  indices.append(prevIndices.begin(), prevIndices.end());
488  }
489 
490  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
491  /// accesses on the to-be-unpacked dimension.
492  ///
493  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
494  /// using the loop iteration variable `iv`.
495  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
496  /// to memory.
497  ///
498  /// Note: For more details, see comments on Strategy<TransferReadOp>.
499  static TransferWriteOp rewriteOp(OpBuilder &b,
500  VectorTransferToSCFOptions options,
501  TransferWriteOp xferOp, Value buffer,
502  Value iv, ValueRange loopState) {
503  SmallVector<Value, 8> loadIndices;
504  getBufferIndices(xferOp, loadIndices);
505  loadIndices.push_back(iv);
506 
507  SmallVector<Value, 8> xferIndices;
508  getXferIndices(b, xferOp, iv, xferIndices);
509 
510  Location loc = xferOp.getLoc();
511  auto vec = memref::LoadOp::create(b, loc, buffer, loadIndices);
512  auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
513  auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
514  Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
515  auto newXferOp = vector::TransferWriteOp::create(
516  b, loc, type, vec, source, xferIndices,
517  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
518  inBoundsAttr);
519 
520  maybeApplyPassLabel(b, newXferOp, options.targetRank);
521 
522  return newXferOp;
523  }
524 
525  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
526  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
527  Value buffer, Value iv,
528  ValueRange loopState) {
529  return isTensorOp(xferOp) ? loopState[0] : Value();
530  }
531 
532  /// Cleanup after rewriting the op.
533  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
534  scf::ForOp forOp) {
535  if (isTensorOp(xferOp)) {
536  assert(forOp->getNumResults() == 1 && "Expected one for loop result");
537  rewriter.replaceOp(xferOp, forOp->getResult(0));
538  } else {
539  rewriter.eraseOp(xferOp);
540  }
541  }
542 
543  /// Return the initial loop state for the generated scf.for loop.
544  static Value initialLoopState(TransferWriteOp xferOp) {
545  return isTensorOp(xferOp) ? xferOp.getBase() : Value();
546  }
547 };
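// For symmetry with the TransferReadOp strategy above, a simplified sketch of
// one unpacking step for a labeled write (illustrative IR only):
// ```
// %vec = memref.load %buf[%i] : memref<5xvector<4x3xf32>>
// vector.transfer_write %vec, %A[%a+%i, %b, %c]
//     : vector<4x3xf32>, memref<?x?x?xf32>
// ```
// is rewritten to:
// ```
// %casted = vector.type_cast %buf
//     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
// for %j = 0 to 4 {
//   %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
//   vector.transfer_write %vec, %A[%a+%i, %b+%j, %c]
//       : vector<3xf32>, memref<?x?x?xf32>
// }
// ```
// As on the read side, the loop and the vector.type_cast are generated by
// TransferOpConversion, not by the strategy alone.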
548 
549 template <typename OpTy>
550 static LogicalResult checkPrepareXferOp(OpTy xferOp, PatternRewriter &rewriter,
551  VectorTransferToSCFOptions options) {
552  if (xferOp->hasAttr(kPassLabel))
553  return rewriter.notifyMatchFailure(
554  xferOp, "kPassLabel is present (vector-to-scf lowering in progress)");
555  if (xferOp.getVectorType().getRank() <= options.targetRank)
556  return rewriter.notifyMatchFailure(
557  xferOp, "xferOp vector rank <= transformation target rank");
558  if (xferOp.getVectorType().getScalableDims().front())
559  return rewriter.notifyMatchFailure(
560  xferOp, "Unpacking of the leading dimension into the memref is not yet "
561  "supported for scalable dims");
562  if (isTensorOp(xferOp) && !options.lowerTensors)
563  return rewriter.notifyMatchFailure(
564  xferOp, "Unpacking for tensors has been disabled.");
565  if (xferOp.getVectorType().getElementType() !=
566  xferOp.getShapedType().getElementType())
567  return rewriter.notifyMatchFailure(
568  xferOp, "Mismatching source and destination element types.");
569 
570  return success();
571 }
572 
573 /// Prepare a TransferReadOp for progressive lowering.
574 ///
575 /// 1. Allocate a temporary buffer.
576 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
577 /// 3. Store the result of the TransferReadOp into the temporary buffer.
578 /// 4. Load the result from the temporary buffer and replace all uses of the
579 /// original TransferReadOp with this load.
580 ///
581 /// E.g.:
582 /// ```
583 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
584 /// : vector<5x4xf32>, memref<?x?x?xf32>
585 /// ```
586 /// is rewritten to:
587 /// ```
588 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
589 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
590 /// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
591 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
592 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
593 /// ```
594 ///
595 /// Note: A second temporary buffer may be allocated for the `mask` operand.
596 struct PrepareTransferReadConversion
597  : public VectorToSCFPattern<TransferReadOp> {
598  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
599 
600  LogicalResult matchAndRewrite(TransferReadOp xferOp,
601  PatternRewriter &rewriter) const override {
602  if (checkPrepareXferOp(xferOp, rewriter, options).failed())
603  return rewriter.notifyMatchFailure(
604  xferOp, "checkPrepareXferOp conditions not met!");
605 
606  auto buffers = allocBuffers(rewriter, xferOp);
607  auto *newXfer = rewriter.clone(*xferOp.getOperation());
608  newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
609  if (xferOp.getMask()) {
610  dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
611  buffers.maskBuffer);
612  }
613 
614  Location loc = xferOp.getLoc();
615  memref::StoreOp::create(rewriter, loc, newXfer->getResult(0),
616  buffers.dataBuffer);
617  rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
618 
619  return success();
620  }
621 };
622 
623 /// Prepare a TransferWriteOp for progressive lowering.
624 ///
625 /// 1. Allocate a temporary buffer.
626 /// 2. Store the vector into the buffer.
627 /// 3. Load the vector from the buffer again.
628 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
629 /// marking it eligible for progressive lowering via TransferOpConversion.
630 ///
631 /// E.g.:
632 /// ```
633 /// vector.transfer_write %vec, %A[%a, %b, %c]
634 /// : vector<5x4xf32>, memref<?x?x?xf32>
635 /// ```
636 /// is rewritten to:
637 /// ```
638 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
639 /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
640 /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
641 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
642 /// : vector<5x4xf32>, memref<?x?x?xf32>
643 /// ```
644 ///
645 /// Note: A second temporary buffer may be allocated for the `mask` operand.
646 struct PrepareTransferWriteConversion
647  : public VectorToSCFPattern<TransferWriteOp> {
648  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
649 
650  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
651  PatternRewriter &rewriter) const override {
652  if (checkPrepareXferOp(xferOp, rewriter, options).failed())
653  return rewriter.notifyMatchFailure(
654  xferOp, "checkPrepareXferOp conditions not met!");
655 
656  Location loc = xferOp.getLoc();
657  auto buffers = allocBuffers(rewriter, xferOp);
658  memref::StoreOp::create(rewriter, loc, xferOp.getVector(),
659  buffers.dataBuffer);
660  auto loadedVec = memref::LoadOp::create(rewriter, loc, buffers.dataBuffer);
661  rewriter.modifyOpInPlace(xferOp, [&]() {
662  xferOp.getValueToStoreMutable().assign(loadedVec);
663  xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
664  });
665 
666  if (xferOp.getMask()) {
667  rewriter.modifyOpInPlace(xferOp, [&]() {
668  xferOp.getMaskMutable().assign(buffers.maskBuffer);
669  });
670  }
671 
672  return success();
673  }
674 };
675 
676 /// Decompose an n-D PrintOp into a loop of elementary/scalar prints. This allows
677 /// printing both 1-D scalable vectors and n-D fixed-size vectors.
678 ///
679 /// E.g.:
680 /// ```
681 /// vector.print %v : vector<[4]xi32>
682 /// ```
683 /// is rewritten to:
684 /// ```
685 /// %c0 = arith.constant 0 : index
686 /// %c4 = arith.constant 4 : index
687 /// %c1 = arith.constant 1 : index
688 /// %vscale = vector.vscale
689 /// %length = arith.muli %vscale, %c4 : index
690 /// %lastIndex = arith.subi %length, %c1 : index
691 /// vector.print punctuation <open>
692 /// scf.for %i = %c0 to %length step %c1 {
693 /// %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
694 /// vector.print %el : i32 punctuation <no_punctuation>
695 /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
696 /// scf.if %notLastIndex {
697 /// vector.print punctuation <comma>
698 /// }
699 /// }
700 /// vector.print punctuation <close>
701 /// vector.print
702 /// ```
703 struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
704  using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
705  LogicalResult matchAndRewrite(vector::PrintOp printOp,
706  PatternRewriter &rewriter) const override {
707  if (!printOp.getSource())
708  return failure();
709 
710  VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
711  if (!vectorType)
712  return failure();
713 
714  // Currently >= 2D scalable vectors are not supported.
715  // These can't be lowered to LLVM (as LLVM does not support scalable vectors
716  // of scalable vectors), and due to limitations of current ops can't be
717  // indexed with SSA values or flattened. This may change after
718  // https://reviews.llvm.org/D155034, though there still needs to be a path
719  // for lowering to LLVM.
720  if (vectorType.getRank() > 1 && vectorType.isScalable())
721  return failure();
722 
723  auto loc = printOp.getLoc();
724  auto value = printOp.getSource();
725 
726  if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
727  // Oddly sized integers are (somewhat) buggy on a lot of backends, so to
728  // avoid issues extend them to a more standard size.
729  // https://github.com/llvm/llvm-project/issues/30613
730  auto width = intTy.getWidth();
731  auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
732  auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
733  intTy.getSignedness());
734  // arith can only take signless integers, so we must cast back and forth.
735  auto signlessSourceVectorType =
736  vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
737  auto signlessTargetVectorType =
738  vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
739  auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
740  value = vector::BitCastOp::create(rewriter, loc, signlessSourceVectorType,
741  value);
742  if (value.getType() != signlessTargetVectorType) {
743  if (width == 1 || intTy.isUnsigned())
744  value = arith::ExtUIOp::create(rewriter, loc,
745  signlessTargetVectorType, value);
746  else
747  value = arith::ExtSIOp::create(rewriter, loc,
748  signlessTargetVectorType, value);
749  }
750  value = vector::BitCastOp::create(rewriter, loc, targetVectorType, value);
751  vectorType = targetVectorType;
752  }
753 
754  auto scalableDimensions = vectorType.getScalableDims();
755  auto shape = vectorType.getShape();
756  constexpr int64_t singletonShape[] = {1};
757  if (vectorType.getRank() == 0)
758  shape = singletonShape;
759 
760  if (vectorType.getRank() != 1) {
761  // Flatten n-D vectors to 1D. This is done to allow indexing with a
762  // non-constant value.
763  auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
764  std::multiplies<int64_t>());
765  auto flatVectorType =
766  VectorType::get({flatLength}, vectorType.getElementType());
767  value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value);
768  }
769 
770  vector::PrintOp firstClose;
771  SmallVector<Value, 8> loopIndices;
772  for (unsigned d = 0; d < shape.size(); d++) {
773  // Setup loop bounds and step.
774  Value lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
775  Value upperBound =
776  arith::ConstantIndexOp::create(rewriter, loc, shape[d]);
777  Value step = arith::ConstantIndexOp::create(rewriter, loc, 1);
778  if (!scalableDimensions.empty() && scalableDimensions[d]) {
779  auto vscale = vector::VectorScaleOp::create(rewriter, loc,
780  rewriter.getIndexType());
781  upperBound = arith::MulIOp::create(rewriter, loc, upperBound, vscale);
782  }
783  auto lastIndex = arith::SubIOp::create(rewriter, loc, upperBound, step);
784 
785  // Create a loop to print the elements surrounded by parentheses.
786  vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
787  auto loop =
788  scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
789  auto printClose = vector::PrintOp::create(
790  rewriter, loc, vector::PrintPunctuation::Close);
791  if (!firstClose)
792  firstClose = printClose;
793 
794  auto loopIdx = loop.getInductionVar();
795  loopIndices.push_back(loopIdx);
796 
797  // Print a comma after all but the last element.
798  rewriter.setInsertionPointToStart(loop.getBody());
799  auto notLastIndex = arith::CmpIOp::create(
800  rewriter, loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
801  scf::IfOp::create(rewriter, loc, notLastIndex,
802  [&](OpBuilder &builder, Location loc) {
803  vector::PrintOp::create(
804  builder, loc, vector::PrintPunctuation::Comma);
805  scf::YieldOp::create(builder, loc);
806  });
807 
808  rewriter.setInsertionPointToStart(loop.getBody());
809  }
810 
811  // Compute the flattened index.
812  // Note: For vectors of rank > 1, this assumes that all dimensions are non-scalable.
813  Value flatIndex;
814  auto currentStride = 1;
815  for (int d = shape.size() - 1; d >= 0; d--) {
816  auto stride =
817  arith::ConstantIndexOp::create(rewriter, loc, currentStride);
818  auto index = arith::MulIOp::create(rewriter, loc, stride, loopIndices[d]);
819  if (flatIndex)
820  flatIndex = arith::AddIOp::create(rewriter, loc, flatIndex, index);
821  else
822  flatIndex = index;
823  currentStride *= shape[d];
824  }
825 
826  // Print the scalar elements in the innermost loop.
827  auto element = vector::ExtractOp::create(rewriter, loc, value, flatIndex);
828  vector::PrintOp::create(rewriter, loc, element,
829  vector::PrintPunctuation::NoPunctuation);
830 
831  rewriter.setInsertionPointAfter(firstClose);
832  vector::PrintOp::create(rewriter, loc, printOp.getPunctuation());
833  rewriter.eraseOp(printOp);
834  return success();
835  }
836 
837  static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
838  return IntegerType::get(intTy.getContext(), intTy.getWidth(),
839  IntegerType::Signless);
840  };
841 };
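// Worked example of the flattened-index computation above (fixed-size case):
// printing a vector<2x3xf32> first shape-casts it to vector<6xf32>, builds two
// nested loops with induction variables %i (extent 2) and %j (extent 3), and
// extracts the element at %i * 3 + %j * 1, i.e. the usual row-major
// linearization. This is also why, for rank > 1, all dimensions must be
// fixed-size: the strides are computed as compile-time constants.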
842 
843 /// Progressive lowering of vector transfer ops: Unpack one dimension.
844 ///
845 /// 1. Unpack one dimension from the current buffer type and cast the buffer
846 /// to that new type. E.g.:
847 /// ```
848 /// %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
849 /// vector.transfer_write %vec ...
850 /// ```
851 /// The following cast is generated:
852 /// ```
853 /// %casted = vector.type_cast %0
854 /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
855 /// ```
856 /// 2. Generate a for loop and rewrite the transfer op according to the
857 /// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
858 /// out-of-bounds, generate an if-check and handle both cases separately.
859 /// 3. Clean up according to the corresponding Strategy<OpTy>.
860 ///
861 /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
862 /// source (as opposed to a memref source), then each iteration of the generated
863 /// scf.for loop yields the new tensor value. E.g.:
864 /// ```
865 /// %result = scf.for i = 0 to 5 {
866 /// %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
867 /// %1 = vector.transfer_write %0, %source[...]
868 /// : vector<4x3xf32>, tensor<5x4x3xf32>
869 /// scf.yield %1 : tensor<5x4x3xf32>
870 /// }
871 /// ```
872 template <typename OpTy>
873 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
874  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
875 
876  void initialize() {
877  // This pattern recursively unpacks one dimension at a time. The recursion
878  // is bounded because the rank is strictly decreasing.
879  this->setHasBoundedRewriteRecursion();
880  }
881 
882  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
883  SmallVectorImpl<Value> &loadIndices,
884  Value iv) {
885  assert(xferOp.getMask() && "Expected transfer op to have mask");
886 
887  // Add load indices from the previous iteration.
888  // The mask buffer depends on the permutation map, which makes determining
889  // the indices quite complex, which is why we need to "look back" at the
890  // previous iteration to find the right indices.
891  Value maskBuffer = getMaskBuffer(xferOp);
892  for (Operation *user : maskBuffer.getUsers()) {
893  // If there is no previous load op, then the indices are empty.
894  if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
895  Operation::operand_range prevIndices = loadOp.getIndices();
896  loadIndices.append(prevIndices.begin(), prevIndices.end());
897  break;
898  }
899  }
900 
901  // In case of broadcast: Use same indices to load from memref
902  // as before.
903  if (!xferOp.isBroadcastDim(0))
904  loadIndices.push_back(iv);
905  }
906 
907  LogicalResult matchAndRewrite(OpTy xferOp,
908  PatternRewriter &rewriter) const override {
909  if (!xferOp->hasAttr(kPassLabel))
910  return rewriter.notifyMatchFailure(
911  xferOp, "kPassLabel is present (progressing lowering in progress)");
912 
913  // Find and cast data buffer. How the buffer can be found depends on OpTy.
914  ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
915  Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
916  auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
917  FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
918  if (failed(castedDataType))
919  return rewriter.notifyMatchFailure(xferOp,
920  "Failed to unpack one vector dim.");
921 
922  auto castedDataBuffer =
923  vector::TypeCastOp::create(locB, *castedDataType, dataBuffer);
924 
925  // If the xferOp has a mask: Find and cast mask buffer.
926  Value castedMaskBuffer;
927  if (xferOp.getMask()) {
928  Value maskBuffer = getMaskBuffer(xferOp);
929  if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
930  // Do not unpack a dimension of the mask, if:
931  // * To-be-unpacked transfer op dimension is a broadcast.
932  // * Mask is 1D, i.e., the mask cannot be further unpacked.
933  // (That means that all remaining dimensions of the transfer op must
934  // be broadcasted.)
935  castedMaskBuffer = maskBuffer;
936  } else {
937  // It's safe to assume the mask buffer can be unpacked if the data
938  // buffer was unpacked.
939  auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
940  MemRefType castedMaskType = *unpackOneDim(maskBufferType);
941  castedMaskBuffer =
942  vector::TypeCastOp::create(locB, castedMaskType, maskBuffer);
943  }
944  }
945 
946  // Loop bounds and step.
947  auto lb = arith::ConstantIndexOp::create(locB, 0);
948  auto ub = arith::ConstantIndexOp::create(
949  locB, castedDataType->getDimSize(castedDataType->getRank() - 1));
950  auto step = arith::ConstantIndexOp::create(locB, 1);
951  // TransferWriteOps that operate on tensors return the modified tensor and
952  // require a loop state.
953  auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
954 
955  // Generate for loop.
956  auto result = scf::ForOp::create(
957  locB, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
958  [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
959  Type stateType = loopState.empty() ? Type() : loopState[0].getType();
960 
961  auto result = generateInBoundsCheck(
962  b, xferOp, iv, unpackedDim(xferOp),
963  stateType ? TypeRange(stateType) : TypeRange(),
964  /*inBoundsCase=*/
965  [&](OpBuilder &b, Location loc) {
966  // Create new transfer op.
967  OpTy newXfer = Strategy<OpTy>::rewriteOp(
968  b, this->options, xferOp, castedDataBuffer, iv, loopState);
969 
970  // If old transfer op has a mask: Set mask on new transfer op.
971  // Special case: If the mask of the old transfer op is 1D and
972  // the unpacked dim is not a broadcast, no mask is needed on
973  // the new transfer op.
974  if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
975  xferOp.getMaskType().getRank() > 1)) {
976  OpBuilder::InsertionGuard guard(b);
977  b.setInsertionPoint(newXfer); // Insert load before newXfer.
978 
979  SmallVector<Value, 8> loadIndices;
980  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
981  loadIndices, iv);
982  auto mask = memref::LoadOp::create(b, loc, castedMaskBuffer,
983  loadIndices);
984  rewriter.modifyOpInPlace(newXfer, [&]() {
985  newXfer.getMaskMutable().assign(mask);
986  });
987  }
988 
989  return loopState.empty() ? Value() : newXfer->getResult(0);
990  },
991  /*outOfBoundsCase=*/
992  [&](OpBuilder &b, Location /*loc*/) {
993  return Strategy<OpTy>::handleOutOfBoundsDim(
994  b, xferOp, castedDataBuffer, iv, loopState);
995  });
996 
997  maybeYieldValue(b, loc, !loopState.empty(), result);
998  });
999 
1000  Strategy<OpTy>::cleanup(rewriter, xferOp, result);
1001  return success();
1002  }
1003 };
1004 
1005 /// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
1006 /// and ConstantMaskOp.
1007 template <typename VscaleConstantBuilder>
1008 static FailureOr<SmallVector<OpFoldResult>>
1009 getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
1010  if (!mask)
1011  return SmallVector<OpFoldResult>{};
1012  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
1013  return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
1014  return OpFoldResult(dimSize);
1015  });
1016  }
1017  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
1018  int dimIdx = 0;
1019  VectorType maskType = constantMask.getVectorType();
1020  auto indexType = IndexType::get(mask.getContext());
1021  return llvm::map_to_vector(
1022  constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
1023  // A scalable dim in a constant_mask means vscale x dimSize.
1024  if (maskType.getScalableDims()[dimIdx++])
1025  return OpFoldResult(createVscaleMultiple(dimSize));
1026  return OpFoldResult(IntegerAttr::get(indexType, dimSize));
1027  });
1028  }
1029  return failure();
1030 }
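// Examples of what getMaskDimSizes returns (illustrative):
//   vector.create_mask %c2, %c3 : vector<4x8xi1>      -> {%c2, %c3} (Values)
//   vector.constant_mask [2, 3] : vector<4x8xi1>      -> {2, 3} (index attrs)
//   vector.constant_mask [2, [4]] : vector<4x[4]xi1>  -> {2, 4 * vscale}
//   no mask                                           -> {}
//   any other mask producer                           -> failure()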
1031 
1032 /// Scalable vector lowering of transfer_write(transpose). This lowering only
1033 /// supports rank 2 (scalable) vectors, but can be used in conjunction with
1034 /// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
1035 /// unrolls until the first scalable dimension.
1036 ///
1037 /// Example:
1038 ///
1039 /// BEFORE:
1040 /// ```mlir
1041 /// %transpose = vector.transpose %vec, [1, 0]
1042 /// : vector<4x[4]xf32> to vector<[4]x4xf32>
1043 /// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
1044 /// : vector<[4]x4xf32>, memref<?x?xf32>
1045 /// ```
1046 ///
1047 /// AFTER:
1048 /// ```mlir
1049 /// %c1 = arith.constant 1 : index
1050 /// %c4 = arith.constant 4 : index
1051 /// %c0 = arith.constant 0 : index
1052 /// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
1053 /// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
1054 /// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
1055 /// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
1056 /// %vscale = vector.vscale
1057 /// %c4_vscale = arith.muli %vscale, %c4 : index
1058 /// scf.for %idx = %c0 to %c4_vscale step %c1 {
1059 /// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
1060 /// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
1061 /// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
1062 /// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
1063 /// %slice_i = affine.apply #map(%idx)[%i]
1064 /// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
1065 /// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
1066 /// : vector<4xf32>, memref<?x?xf32>
1067 /// }
1068 /// ```
1069 struct ScalableTransposeTransferWriteConversion
1070  : VectorToSCFPattern<vector::TransferWriteOp> {
1071  using VectorToSCFPattern::VectorToSCFPattern;
1072 
1073  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
1074  PatternRewriter &rewriter) const override {
1075  if (failed(checkLowerTensors(writeOp, rewriter)))
1076  return failure();
1077 
1078  VectorType vectorType = writeOp.getVectorType();
1079 
1080  // Note: By comparing the scalable dims to an ArrayRef of length two, this
1081  // implicitly checks that the rank is also two.
1082  ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
1083  if (scalableFlags != ArrayRef<bool>{true, false}) {
1084  return rewriter.notifyMatchFailure(
1085  writeOp, "expected vector of the form vector<[N]xMxty>");
1086  }
1087 
1088  auto permutationMap = writeOp.getPermutationMap();
1089  if (!permutationMap.isIdentity()) {
1090  return rewriter.notifyMatchFailure(
1091  writeOp, "non-identity permutations are unsupported (lower first)");
1092  }
1093 
1094  // Note: This pattern is only lowering the leading dimension (to a loop),
1095  // so we only check if the leading dimension is in bounds. The in-bounds
1096  // attribute for the trailing dimension will be propagated.
1097  if (!writeOp.isDimInBounds(0)) {
1098  return rewriter.notifyMatchFailure(
1099  writeOp, "out-of-bounds dims are unsupported (use masking)");
1100  }
1101 
1102  Value vector = writeOp.getVector();
1103  auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
1104  if (!transposeOp ||
1105  transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
1106  return rewriter.notifyMatchFailure(writeOp, "source not transpose");
1107  }
1108 
1109  auto loc = writeOp.getLoc();
1110  auto createVscaleMultiple =
1111  vector::makeVscaleConstantBuilder(rewriter, loc);
1112 
1113  auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
1114  if (failed(maskDims)) {
1115  return rewriter.notifyMatchFailure(writeOp,
1116  "failed to resolve mask dims");
1117  }
1118 
1119  int64_t fixedDimSize = vectorType.getDimSize(1);
1120  auto fixedDimOffsets = llvm::seq(fixedDimSize);
1121 
1122  // Extract all slices from the source of the transpose.
1123  auto transposeSource = transposeOp.getVector();
1124  SmallVector<Value> transposeSourceSlices =
1125  llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1126  return vector::ExtractOp::create(rewriter, loc, transposeSource, idx);
1127  });
1128 
1129  // Loop bounds and step.
1130  auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
1131  auto ub =
1132  maskDims->empty()
1133  ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
1134  : vector::getAsValues(rewriter, loc, maskDims->front()).front();
1135  auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
1136 
1137  // Generate a new mask for the slice.
1138  VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
1139  Value sliceMask = nullptr;
1140  if (!maskDims->empty()) {
1141  sliceMask = vector::CreateMaskOp::create(
1142  rewriter, loc, sliceType.clone(rewriter.getI1Type()),
1143  ArrayRef<OpFoldResult>(*maskDims).drop_front());
1144  }
1145 
1146  Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{};
1147  ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
1148  auto result = scf::ForOp::create(
1149  rewriter, loc, lb, ub, step, initLoopArgs,
1150  [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
1151  // Indices for the new transfer op.
1152  SmallVector<Value, 8> xferIndices;
1153  getXferIndices(b, writeOp, iv, xferIndices);
1154 
1155  // Extract a transposed slice from the source vector.
1156  SmallVector<Value> transposeElements =
1157  llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1158  return vector::ExtractOp::create(
1159  b, loc, transposeSourceSlices[idx], iv);
1160  });
1161  auto sliceVec = vector::FromElementsOp::create(b, loc, sliceType,
1162  transposeElements);
1163 
1164  // Create the transfer_write for the slice.
1165  Value dest =
1166  loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
1167  auto newWriteOp = vector::TransferWriteOp::create(
1168  b, loc, sliceVec, dest, xferIndices,
1169  ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
1170  if (sliceMask)
1171  newWriteOp.getMaskMutable().assign(sliceMask);
1172 
1173  // Yield from the loop.
1174  scf::YieldOp::create(b, loc,
1175  loopIterArgs.empty() ? ValueRange{}
1176  : newWriteOp.getResult());
1177  });
1178 
1179  if (isTensorOp(writeOp))
1180  rewriter.replaceOp(writeOp, result);
1181  else
1182  rewriter.eraseOp(writeOp);
1183 
1184  return success();
1185  }
1186 };
1187 
1188 } // namespace lowering_n_d
1189 
1190 namespace lowering_n_d_unrolled {
1191 
1192 /// If the original transfer op has a mask, compute the mask of the new transfer
1193 /// op (for the current iteration `i`) and assign it.
1194 template <typename OpTy>
1195 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
1196  int64_t i) {
1197  if (!xferOp.getMask())
1198  return;
1199 
1200  if (xferOp.isBroadcastDim(0)) {
1201  // To-be-unpacked dimension is a broadcast, which does not have a
1202  // corresponding mask dimension. Mask attribute remains unchanged.
1203  newXferOp.getMaskMutable().assign(xferOp.getMask());
1204  return;
1205  }
1206 
1207  if (xferOp.getMaskType().getRank() > 1) {
1208  // Unpack one dimension of the mask.
1209  OpBuilder::InsertionGuard guard(b);
1210  b.setInsertionPoint(newXferOp); // Insert load before newXfer.
1211 
1212  llvm::SmallVector<int64_t, 1> indices({i});
1213  Location loc = xferOp.getLoc();
1214  auto newMask = vector::ExtractOp::create(b, loc, xferOp.getMask(), indices);
1215  newXferOp.getMaskMutable().assign(newMask);
1216  }
1217 
1218  // If we end up here: The mask of the old transfer op is 1D and the unpacked
1219  // dim is not a broadcast, so no mask is needed on the new transfer op.
1220  // `generateInBoundsCheck` will have evaluated the mask already.
1221 }
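// Example (illustrative): when unrolling iteration i = 2 of a transfer whose
// mask is %mask : vector<5x4xi1> (and dim 0 is not a broadcast), the new
// (N-1)-d transfer op receives
// ```
// %submask = vector.extract %mask[2] : vector<4xi1> from vector<5x4xi1>
// ```
// For a 1-D mask nothing is extracted here; generateInBoundsCheck has already
// folded the i-th mask bit into the surrounding scf.if condition.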
1222 
1223 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
1224 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
1225 /// memref buffer is allocated and the SCF loop is fully unrolled.
1226 ///
1227 /// E.g.:
1228 /// ```
1230 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
1231 /// : memref<?x?x?xf32>, vector<5x4xf32>
1232 /// ```
1233 /// is rewritten to IR such as (simplified):
1234 /// ```
1235 /// %v_init = splat %padding : vector<5x4xf32>
1236 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
1237 /// : memref<?x?x?xf32>, vector<4xf32>
1238 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
1239 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
1240 /// : memref<?x?x?xf32>, vector<4xf32>
1241 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
1242 /// ...
1243 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
1244 /// : memref<?x?x?xf32>, vector<4xf32>
1245 /// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
1246 /// ```
1247 ///
1248 /// Note: As an optimization, if the result of the original TransferReadOp
1249 /// was directly inserted into another vector, no new %v_init vector is created.
1250 /// Instead, the new TransferReadOp results are inserted into that vector.
1251 struct UnrollTransferReadConversion
1252  : public VectorToSCFPattern<TransferReadOp> {
1253  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
1254 
1255  void initialize() {
1256  // This pattern recursively unpacks one dimension at a time. The recursion
1257  // is bounded because the rank is strictly decreasing.
1258  setHasBoundedRewriteRecursion();
1259  }
1260 
1261  /// Get or build the vector into which the newly created TransferReadOp
1262  /// results are inserted.
1263  Value buildResultVector(PatternRewriter &rewriter,
1264  TransferReadOp xferOp) const {
1265  if (auto insertOp = getInsertOp(xferOp))
1266  return insertOp.getDest();
1267  Location loc = xferOp.getLoc();
1268  return vector::BroadcastOp::create(rewriter, loc, xferOp.getVectorType(),
1269  xferOp.getPadding());
1270  }
1271 
1272  /// If the result of the TransferReadOp has exactly one user, which is a
1273  /// vector::InsertOp, return that operation.
1274  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
1275  if (xferOp->hasOneUse()) {
1276  Operation *xferOpUser = *xferOp->getUsers().begin();
1277  if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
1278  return insertOp;
1279  }
1280 
1281  return vector::InsertOp();
1282  }
1283 
1284  /// If the result of the TransferReadOp has exactly one user, which is a
1285  /// vector::InsertOp, return that operation's indices.
1286  void getInsertionIndices(TransferReadOp xferOp,
1287  SmallVectorImpl<OpFoldResult> &indices) const {
1288  if (auto insertOp = getInsertOp(xferOp)) {
1289  auto pos = insertOp.getMixedPosition();
1290  indices.append(pos.begin(), pos.end());
1291  }
1292  }
1293 
1294  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1295  /// accesses, and broadcasts and transposes in permutation maps.
1296  LogicalResult matchAndRewrite(TransferReadOp xferOp,
1297  PatternRewriter &rewriter) const override {
1298  if (xferOp.getVectorType().getRank() <= options.targetRank)
1299  return rewriter.notifyMatchFailure(
1300  xferOp, "vector rank is less or equal to target rank");
1301  if (failed(checkLowerTensors(xferOp, rewriter)))
1302  return failure();
1303  if (xferOp.getVectorType().getElementType() !=
1304  xferOp.getShapedType().getElementType())
1305  return rewriter.notifyMatchFailure(
1306  xferOp, "not yet supported: element type mismatch");
1307  auto xferVecType = xferOp.getVectorType();
1308  if (xferVecType.getScalableDims()[0]) {
1309  return rewriter.notifyMatchFailure(
1310  xferOp, "scalable dimensions cannot be unrolled at compile time");
1311  }
1312 
1313  auto insertOp = getInsertOp(xferOp);
1314  auto vec = buildResultVector(rewriter, xferOp);
1315  auto vecType = dyn_cast<VectorType>(vec.getType());
1316 
1317  VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
1318 
1319  int64_t dimSize = xferVecType.getShape()[0];
1320 
1321  // Generate fully unrolled loop of transfer ops.
1322  Location loc = xferOp.getLoc();
1323  for (int64_t i = 0; i < dimSize; ++i) {
1324  Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);
1325 
1326  // FIXME: Rename this lambda - it does much more than just
1327  // in-bounds-check generation.
1328  vec = generateInBoundsCheck(
1329  rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
1330  /*inBoundsCase=*/
1331  [&](OpBuilder &b, Location loc) {
1332  // Indices for the new transfer op.
1333  SmallVector<Value, 8> xferIndices;
1334  getXferIndices(b, xferOp, iv, xferIndices);
1335 
1336  // Indices for the new vector.insert op.
1337  SmallVector<OpFoldResult, 8> insertionIndices;
1338  getInsertionIndices(xferOp, insertionIndices);
1339  insertionIndices.push_back(rewriter.getIndexAttr(i));
1340 
1341  auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1342 
1343  auto newXferOp = vector::TransferReadOp::create(
1344  b, loc, newXferVecType, xferOp.getBase(), xferIndices,
1345  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1346  xferOp.getPadding(), Value(), inBoundsAttr);
1347  maybeAssignMask(b, xferOp, newXferOp, i);
1348 
1349  Value valToInsert = newXferOp.getResult();
1350  if (newXferVecType.getRank() == 0) {
1351  // vector.insert does not accept rank-0 as the non-indexed
1352  // argument. Extract the scalar before inserting.
1353  valToInsert = vector::ExtractOp::create(b, loc, valToInsert,
1354  SmallVector<int64_t>());
1355  }
1356  return vector::InsertOp::create(b, loc, valToInsert, vec,
1357  insertionIndices);
1358  },
1359  /*outOfBoundsCase=*/
1360  [&](OpBuilder &b, Location loc) {
1361  // Forward the original (unmodified) vector.
1362  return vec;
1363  });
1364  }
1365 
1366  if (insertOp) {
1367  // Rewrite single user of the old TransferReadOp, which was an InsertOp.
1368  rewriter.replaceOp(insertOp, vec);
1369  rewriter.eraseOp(xferOp);
1370  } else {
1371  rewriter.replaceOp(xferOp, vec);
1372  }
1373 
1374  return success();
1375  }
1376 };
1377 
1378 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
1379 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
1380 /// memref buffer is allocated and the SCF loop is fully unrolled.
1381 ///
1382 /// E.g.:
1383 /// ```
1385 /// vector.transfer_write %vec, %A[%a, %b, %c]
1386 /// : vector<5x4xf32>, memref<?x?x?xf32>
1387 /// ```
1388 /// is rewritten to IR such as (simplified):
1389 /// ```
1390 /// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32>
1391 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
1392 /// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32>
1393 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
1394 /// ...
1395 /// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32>
1396 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
1397 /// ```
1398 ///
1399 /// Note: As an optimization, if the vector of the original TransferWriteOp
1400 /// was directly extracted from another vector via an ExtractOp `a`, extract
1401 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
1402 /// doing so, `a` may become dead, and the number of ExtractOps generated during
1403 /// recursive application of this pattern will be minimal.
1404 struct UnrollTransferWriteConversion
1405  : public VectorToSCFPattern<TransferWriteOp> {
1406  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
1407 
1408  void initialize() {
1409  // This pattern recursively unpacks one dimension at a time. The recursion
1410  // is bounded because the rank is strictly decreasing.
1411  setHasBoundedRewriteRecursion();
1412  }
1413 
1414  /// Return the vector from which newly generated ExtractOps will extract.
1415  Value getDataVector(TransferWriteOp xferOp) const {
1416  if (auto extractOp = getExtractOp(xferOp))
1417  return extractOp.getVector();
1418  return xferOp.getVector();
1419  }
1420 
1421  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
1422  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
1423  if (auto *op = xferOp.getVector().getDefiningOp())
1424  return dyn_cast<vector::ExtractOp>(op);
1425  return vector::ExtractOp();
1426  }
1427 
1428  /// If the input of the given TransferWriteOp is an ExtractOp, return its
1429  /// indices.
1430  void getExtractionIndices(TransferWriteOp xferOp,
1431  SmallVectorImpl<OpFoldResult> &indices) const {
1432  if (auto extractOp = getExtractOp(xferOp)) {
1433  auto pos = extractOp.getMixedPosition();
1434  indices.append(pos.begin(), pos.end());
1435  }
1436  }
1437 
1438  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1439  /// accesses, and broadcasts and transposes in permutation maps.
1440  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1441  PatternRewriter &rewriter) const override {
1442  VectorType inputVectorTy = xferOp.getVectorType();
1443 
1444  if (inputVectorTy.getRank() <= options.targetRank)
1445  return failure();
1446 
1447  if (failed(checkLowerTensors(xferOp, rewriter)))
1448  return failure();
 1449  // Transfer ops that modify the element type are not currently supported.
1450  if (inputVectorTy.getElementType() !=
1451  xferOp.getShapedType().getElementType())
1452  return failure();
1453 
1454  auto vec = getDataVector(xferOp);
1455  if (inputVectorTy.getScalableDims()[0]) {
1456  // Cannot unroll a scalable dimension at compile time.
1457  return failure();
1458  }
1459 
1460  int64_t dimSize = inputVectorTy.getShape()[0];
1461  Value source = xferOp.getBase(); // memref or tensor to be written to.
1462  auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1463 
1464  // Generate fully unrolled loop of transfer ops.
1465  Location loc = xferOp.getLoc();
1466  for (int64_t i = 0; i < dimSize; ++i) {
1467  Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);
1468 
1469  auto updatedSource = generateInBoundsCheck(
1470  rewriter, xferOp, iv, unpackedDim(xferOp),
1471  isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1472  /*inBoundsCase=*/
1473  [&](OpBuilder &b, Location loc) {
1474  // Indices for the new transfer op.
1475  SmallVector<Value, 8> xferIndices;
1476  getXferIndices(b, xferOp, iv, xferIndices);
1477 
1478  // Indices for the new vector.extract op.
1479  SmallVector<OpFoldResult, 8> extractionIndices;
1480  getExtractionIndices(xferOp, extractionIndices);
1481  extractionIndices.push_back(b.getI64IntegerAttr(i));
1482 
1483  auto extracted =
1484  vector::ExtractOp::create(b, loc, vec, extractionIndices);
1485  auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1486  Value xferVec;
1487  if (inputVectorTy.getRank() == 1) {
 1488  // When target-rank=0, unrolling would cause the vector input
 1489  // argument of `transfer_write` to become a scalar. We solve
 1490  // this by broadcasting the scalar to a 0-D vector.
1491  xferVec = vector::BroadcastOp::create(
1492  b, loc, VectorType::get({}, extracted.getType()), extracted);
1493  } else {
1494  xferVec = extracted;
1495  }
1496  auto newXferOp = vector::TransferWriteOp::create(
1497  b, loc, sourceType, xferVec, source, xferIndices,
1498  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
1499  inBoundsAttr);
1500 
1501  maybeAssignMask(b, xferOp, newXferOp, i);
1502 
1503  return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1504  },
1505  /*outOfBoundsCase=*/
1506  [&](OpBuilder &b, Location loc) {
1507  return isTensorOp(xferOp) ? source : Value();
1508  });
1509 
1510  if (isTensorOp(xferOp))
1511  source = updatedSource;
1512  }
1513 
1514  if (isTensorOp(xferOp))
1515  rewriter.replaceOp(xferOp, source);
1516  else
1517  rewriter.eraseOp(xferOp);
1518 
1519  return success();
1520  }
1521 };
1522 
1523 } // namespace lowering_n_d_unrolled
1524 
1525 namespace lowering_1_d {
1526 
1527 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
1528 /// part of TransferOp1dConversion. Return the memref dimension on which
1529 /// the transfer is operating. A return value of std::nullopt indicates a
1530 /// broadcast.
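/// For example (an illustrative sketch; the indices `%a`, `%b` and the map are
/// assumed): for a 1-D transfer with indices [%a, %b] and
/// permutation_map = affine_map<(d0, d1) -> (d0)>, iteration `%iv` produces
/// memref indices [affine.apply(d0 + d1)(%a, %iv), %b] and returns dimension 0.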
1531 template <typename OpTy>
1532 static std::optional<int64_t>
1533 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1534  SmallVector<Value, 8> &memrefIndices) {
1535  auto indices = xferOp.getIndices();
1536  auto map = xferOp.getPermutationMap();
1537  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
1538 
1539  memrefIndices.append(indices.begin(), indices.end());
1540  assert(map.getNumResults() == 1 &&
1541  "Expected 1 permutation map result for 1D transfer");
1542  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
1543  Location loc = xferOp.getLoc();
1544  auto dim = expr.getPosition();
1545  AffineExpr d0, d1;
1546  bindDims(xferOp.getContext(), d0, d1);
1547  Value offset = memrefIndices[dim];
1548  memrefIndices[dim] =
1549  affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1550  return dim;
1551  }
1552 
1553  assert(xferOp.isBroadcastDim(0) &&
1554  "Expected AffineDimExpr or AffineConstantExpr");
1555  return std::nullopt;
1556 }
1557 
1558 /// Codegen strategy for TransferOp1dConversion, depending on the
1559 /// operation.
1560 template <typename OpTy>
1561 struct Strategy1d;
1562 
1563 /// Codegen strategy for TransferReadOp.
1564 template <>
1565 struct Strategy1d<TransferReadOp> {
1566  static void generateForLoopBody(OpBuilder &b, Location loc,
1567  TransferReadOp xferOp, Value iv,
1568  ValueRange loopState) {
1569  SmallVector<Value, 8> indices;
1570  auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1571  auto vec = loopState[0];
1572 
1573  // In case of out-of-bounds access, leave `vec` as is (was initialized with
1574  // padding value).
1575  auto nextVec = generateInBoundsCheck(
1576  b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1577  /*inBoundsCase=*/
1578  [&](OpBuilder &b, Location loc) {
1579  Value val = memref::LoadOp::create(b, loc, xferOp.getBase(), indices);
1580  return vector::InsertOp::create(b, loc, val, vec, iv);
1581  },
1582  /*outOfBoundsCase=*/
1583  [&](OpBuilder & /*b*/, Location loc) { return vec; });
1584  scf::YieldOp::create(b, loc, nextVec);
1585  }
1586 
1587  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
 1588  // Initialize the vector with the padding value.
1589  Location loc = xferOp.getLoc();
1590  return vector::BroadcastOp::create(b, loc, xferOp.getVectorType(),
1591  xferOp.getPadding());
1592  }
1593 };
1594 
1595 /// Codegen strategy for TransferWriteOp.
1596 template <>
1597 struct Strategy1d<TransferWriteOp> {
1598  static void generateForLoopBody(OpBuilder &b, Location loc,
1599  TransferWriteOp xferOp, Value iv,
1600  ValueRange /*loopState*/) {
1601  SmallVector<Value, 8> indices;
1602  auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1603 
1604  // Nothing to do in case of out-of-bounds access.
1605  generateInBoundsCheck(
1606  b, xferOp, iv, dim,
1607  /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1608  auto val = vector::ExtractOp::create(b, loc, xferOp.getVector(), iv);
1609  memref::StoreOp::create(b, loc, val, xferOp.getBase(), indices);
1610  });
1611  scf::YieldOp::create(b, loc);
1612  }
1613 
1614  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1615  return Value();
1616  }
1617 };
1618 
1619 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1620 /// necessary in cases where a 1D vector transfer op cannot be lowered into
1621 /// vector load/stores due to non-unit strides or broadcasts:
1622 ///
1623 /// * Transfer dimension is not the last memref dimension
1624 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1625 /// * Memref has a layout map with non-unit stride on the last dimension
1626 ///
1627 /// This pattern generates IR as follows:
1628 ///
1629 /// 1. Generate a for loop iterating over each vector element.
 1630 /// 2. Inside the loop, generate a memref.load + vector.insert (read) or a
 1631 /// vector.extract + memref.store (write), depending on OpTy.
1632 ///
1633 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1634 /// can be generated instead of TransferOp1dConversion. Add such a pattern
1635 /// to ConvertVectorToLLVM.
1636 ///
1637 /// E.g.:
1638 /// ```
1639 /// vector.transfer_write %vec, %A[%a, %b]
1640 /// {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1641 /// : vector<9xf32>, memref<?x?xf32>
1642 /// ```
 1643 /// is rewritten to approximately the following pseudo-IR:
1644 /// ```
1645 /// for i = 0 to 9 {
1646 /// %t = vector.extract %vec[i] : f32 from vector<9xf32>
 1647 /// memref.store %t, %A[%a + i, %b] : memref<?x?xf32>
1648 /// }
1649 /// ```
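///
/// The TransferReadOp case is analogous (an illustrative sketch, based on the
/// read variant of the example above):
/// ```
/// %vec = vector.transfer_read %A[%a, %b], %padding
///     {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///     : memref<?x?xf32>, vector<9xf32>
/// ```
/// becomes approximately the following pseudo-IR:
/// ```
/// %init = vector.broadcast %padding : f32 to vector<9xf32>
/// %vec = for i = 0 to 9 iter_args(%v = %init) {
///   %t = memref.load %A[%a + i, %b] : memref<?x?xf32>
///   %v_next = vector.insert %t, %v [i] : f32 into vector<9xf32>
///   yield %v_next
/// }
/// ```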
1650 template <typename OpTy>
1651 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1652  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1653 
1654  LogicalResult matchAndRewrite(OpTy xferOp,
1655  PatternRewriter &rewriter) const override {
1656  // TODO: support 0-d corner case.
1657  if (xferOp.getTransferRank() == 0)
1658  return failure();
1659  auto map = xferOp.getPermutationMap();
1660  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
1661 
1662  if (!memRefType)
1663  return failure();
1664  if (xferOp.getVectorType().getRank() != 1)
1665  return failure();
1666  if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
1667  return failure(); // Handled by ConvertVectorToLLVM
1668 
1669  // Loop bounds, step, state...
1670  Location loc = xferOp.getLoc();
1671  auto vecType = xferOp.getVectorType();
1672  auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
1673  Value ub =
1674  arith::ConstantIndexOp::create(rewriter, loc, vecType.getDimSize(0));
1675  if (vecType.isScalable()) {
1676  Value vscale =
1677  vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
1678  ub = arith::MulIOp::create(rewriter, loc, ub, vscale);
1679  }
1680  auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
1681  auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1682 
1683  // Generate for loop.
1684  rewriter.replaceOpWithNewOp<scf::ForOp>(
1685  xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1686  [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
1687  Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1688  });
1689 
1690  return success();
1691  }
1692 };
1693 
1694 } // namespace lowering_1_d
1695 } // namespace
1696 
 1697 void mlir::populateVectorToSCFConversionPatterns(
 1698  RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
 1699  if (options.unroll) {
1700  patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1701  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1702  patterns.getContext(), options);
1703  } else {
1704  patterns.add<lowering_n_d::PrepareTransferReadConversion,
1705  lowering_n_d::PrepareTransferWriteConversion,
1706  lowering_n_d::TransferOpConversion<TransferReadOp>,
1707  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1708  patterns.getContext(), options);
1709  }
1710  if (options.lowerScalable) {
1711  patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
1712  patterns.getContext(), options);
1713  }
1714  if (options.targetRank == 1) {
1715  patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1716  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1717  patterns.getContext(), options);
1718  }
1719  patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
1720  options);
1721 }
1722 
1723 namespace {
1724 
1725 struct ConvertVectorToSCFPass
1726  : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1727  ConvertVectorToSCFPass() = default;
1728  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1729  this->fullUnroll = options.unroll;
1730  this->targetRank = options.targetRank;
1731  this->lowerTensors = options.lowerTensors;
1732  this->lowerScalable = options.lowerScalable;
1733  }
1734 
1735  void runOnOperation() override {
 1736  VectorTransferToSCFOptions options;
 1737  options.unroll = fullUnroll;
1738  options.targetRank = targetRank;
1739  options.lowerTensors = lowerTensors;
1740  options.lowerScalable = lowerScalable;
1741 
1742  // Lower permutation maps first.
1743  RewritePatternSet lowerTransferPatterns(&getContext());
 1744  mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
 1745  lowerTransferPatterns);
1746  (void)applyPatternsGreedily(getOperation(),
1747  std::move(lowerTransferPatterns));
1748 
 1749  RewritePatternSet patterns(&getContext());
 1750  populateVectorToSCFConversionPatterns(patterns, options);
 1751  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
1752  }
1753 };
1754 
1755 } // namespace
1756 
1757 std::unique_ptr<Pass>
 1758 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
 1759  return std::make_unique<ConvertVectorToSCFPass>(options);
1760 }
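// Illustrative usage sketch (not part of this file): driving this lowering
// programmatically. `pm` is an assumed, already-configured mlir::PassManager;
// the option fields mirror the assignments in ConvertVectorToSCFPass above.
//
//   VectorTransferToSCFOptions options;
//   options.unroll = true;    // fully unroll instead of allocating a buffer
//   options.targetRank = 1;   // also register the 1-D scalar-loop patterns
//   pm.addPass(createConvertVectorToSCFPass(options));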