//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <optional>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/STLExtras.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;
/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
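/// For instance (an illustrative sketch, assuming a transfer_read of
/// vector<4x8xf32> from memref<?x?xf32> at indices %i, %j with no dimension
/// statically known to be in-bounds), the emitted check is roughly:
/// ```
///   %end0 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
///   %d0 = memref.dim %A, %c0 : memref<?x?xf32>
///   %ok0 = arith.cmpi sle, %end0, %d0 : index
///   %end1 = affine.apply affine_map<(d0) -> (d0 + 8)>(%j)
///   %d1 = memref.dim %A, %c1 : memref<?x?xf32>
///   %ok1 = arith.cmpi sle, %end1, %d1 : index
///   %inBounds = arith.andi %ok0, %ok1 : i1
/// ```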
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getBase(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sle,
                              getValueOrCreateConstantIndexOp(b, loc, sum),
                              getValueOrCreateConstantIndexOp(b, loc, dimSz));
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = arith::AndIOp::create(b, loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer for one vector, `alloca`'d at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getBase()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an scf.if op; this avoids
  // applying the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
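///
/// For instance (an illustrative example of the rules above, not taken from a
/// test): given `memref<4x8xf32, strided<[8, 1]>>` and
/// `memref<4x16xf32, strided<[16, 1]>>`, the two types are not
/// cast-compatible, so the result keeps the equal static size 4 and stride 1
/// and relaxes the rest: `memref<4x?xf32, strided<[?, 1]>>`.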
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}

/// Casts the given memref to a compatible memref type. If the source memref
/// has a different address space than the target type, a
/// `memref.memory_space_cast` is first inserted, followed by a `memref.cast`.
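///
/// For example (a sketch with assumed types; address space 3 is arbitrary):
/// casting `%m : memref<4x8xf32, 3>` to `memref<?x8xf32>` produces roughly:
/// ```
///   %0 = memref.memory_space_cast %m : memref<4x8xf32, 3> to memref<4x8xf32>
///   %1 = memref.cast %0 : memref<4x8xf32> to memref<?x8xf32>
/// ```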
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res =
        memref::MemorySpaceCastOp::create(b, memref.getLoc(), sourceType, res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return memref::CastOp::create(b, memref.getLoc(), compatibleMemRefType, res);
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.getBase()` @ `xferOp.getIndices()` and the view `alloc`.
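///
/// For instance (a sketch for a 2-D transfer on %A at indices %i, %j, with an
/// `alloc` of type memref<4x8xf32>; names are illustrative), the clipped sizes
/// and subviews look roughly like:
/// ```
///   %sz0 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%dimA0, %i, %c4)
///   %sz1 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%dimA1, %j, %c8)
///   %src = memref.subview %A[%i, %j] [%sz0, %sz1] [1, 1]
///   %dst = memref.subview %alloc[0, 0] [%sz0, %sz1] [1, 1]
/// ```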
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef =
        memref::DimOp::create(b, xferOp.getLoc(), xferOp.getBase(), indicesIdx);
    Value dimAlloc = memref::DimOp::create(b, loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin =
        affine::AffineMinOp::create(b, loc, index.getType(), maps[0],
                                    ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = memref::SubViewOp::create(
      b, loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
  auto copyDest = memref::SubViewOp::create(
      b, loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
      b, loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        scf::YieldOp::create(b, loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        linalg::FillOp::create(b, loc, ValueRange{xferOp.getPadding()},
                               ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        scf::YieldOp::create(b, loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
      b, loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        scf::YieldOp::create(b, loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        memref::StoreOp::create(
            b, loc, vector,
            vector::TypeCastOp::create(
                b, loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        scf::YieldOp::create(b, loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
             b, loc, inBoundsCond,
             [&](OpBuilder &b, Location loc) {
               Value res =
                   castToCompatibleMemRefType(b, memref, compatibleMemRefType);
               scf::ValueVector viewAndIndices{res};
               llvm::append_range(viewAndIndices, xferOp.getIndices());
               scf::YieldOp::create(b, loc, viewAndIndices);
             },
             [&](OpBuilder &b, Location loc) {
               Value casted =
                   castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
               scf::ValueVector viewAndIndices{casted};
               viewAndIndices.insert(viewAndIndices.end(),
                                     xferOp.getTransferRank(), zero);
               scf::YieldOp::create(b, loc, viewAndIndices);
             })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = arith::XOrIOp::create(
      b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1));
  scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
    scf::YieldOp::create(b, loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = arith::XOrIOp::create(
      b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1));
  scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = memref::LoadOp::create(
        b, loc,
        vector::TypeCastOp::create(
            b, loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange());
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    scf::YieldOp::create(b, loc, ValueRange{});
  });
}

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
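/// Return the closest enclosing op with the AutomaticAllocationScope trait
/// that is not a known looping construct. For example, for a transfer op
/// nested in an scf.for inside a func.func, this returns the func.func, so
/// that the transient buffer is alloca'd once rather than per iteration.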
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting alloca's in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, affine::AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer for one vector, `alloca`'d at the top of the
/// function.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block to
/// further copy a partial buffer into the final result in the slow path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a buffer for one vector, `alloca`'d at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getBase()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = memref::AllocaOp::create(b, scope->getLoc(),
                                     MemRefType::get(shape, elementType),
                                     ValueRange{}, b.getI64IntegerAttr(32));
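    // E.g. for a transfer of vector<4x8xf32> this hoists, at the top of the
    // surrounding allocation scope, something like (illustrative):
    //   %alloc = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>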
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getBase(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
        filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}
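
// Example client-side usage (a sketch, not part of this file; `ctx` and
// `funcOp` are assumed to be in scope): register the pattern and run it with
// the greedy rewrite driver, selecting the split strategy via
// `VectorTransformsOptions`, e.g.:
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferFullPartialPatterns(
//       patterns, VectorTransformsOptions().setVectorTransferSplit(
//                     VectorTransferSplit::LinalgCopy));
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));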