//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

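/// Return the constant index value of `v` if it is defined by an
/// arith.constant of index type, or by an affine.apply whose map is a single
/// constant result; return std::nullopt otherwise.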
static std::optional<int64_t> extractConstantIndex(Value v) {
  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
    return cstOp.value();
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return std::nullopt;
}

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly. Returns an empty Value when the comparison
// `v <= ub` is statically known to hold (both operands are constants with
// `v < ub`).
static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
}

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    auto d0 = getAffineDimExpr(0, xferOp.getContext());
    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
    Value sum =
        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
    Value cond = createFoldedSLE(
        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
    if (!cond)
      return;
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under IfOp; this avoids applying
  // the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
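/// Returns the pair (copy source, copy destination) as subviews of matching
/// sizes: for a transfer_write the copy source is `alloc`, for a
/// transfer_read the copy destination is `alloc`.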
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
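  // For each transferred dimension, the size of the intersection is
  // min(dim(source) - transfer index, dim(alloc)), i.e., what remains of the
  // source view starting at the transfer indices, clamped to the size of the
  // single-vector buffer.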
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
                                                xferOp.source(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
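        // Slow path: fill the single-vector buffer with the padding value,
        // then copy in the in-bounds subset of the source data.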
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
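        // Slow path: clone the original (potentially out-of-bounds) transfer
        // read and store its result vector into the single-vector buffer
        // `alloc`.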
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res = memref;
            if (compatibleMemRefType != xferOp.getShapedType())
              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
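    // Slow path: copy the in-bounds intersection of the temporary buffer back
    // into the original view.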
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
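    // Slow path: reload the full vector from the temporary buffer and replay
    // the original (out-of-bounds) transfer_write against the original view.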
    IRMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange());
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
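/// Return the op with the AutomaticAllocationScope trait (typically the
/// enclosing function) under which the transient single-vector buffer should
/// be alloca'ed, skipping over surrounding loop constructs.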
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting alloca's in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block to
/// further copy a partial buffer into the final result in the slow path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

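  // `inBoundsAttr` marks every transferred dimension as in-bounds. It is
  // attached either to the original op (ForceInBounds) or to the rewritten
  // fast-path transfer below.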
  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.updateRootInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }


  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
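    // `alloc` holds exactly one vector and lives at the top of the enclosing
    // allocation scope, so the slow path can always address it at [0 ... 0].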
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set existing read op to in-bounds; it always reads from a full buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.updateRootInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
        filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}