//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
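/// For example (an illustrative sketch; the SSA names are made up), for a
/// transfer of `vector<4xf32>` from `%A[%i]` with a potentially out-of-bounds
/// dimension, the emitted check is roughly:
/// ```
///    %end  = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
///    %dim  = memref.dim %A, %c0 : memref<?xf32>
///    %cond = arith.cmpi sle, %end, %dim : index
/// ```
/// Conditions for multiple dimensions are combined with `arith.andi`.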
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getSource(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
                                getValueOrCreateConstantIndexOp(b, loc, sum),
                                getValueOrCreateConstantIndexOp(b, loc, dimSz));
    // Conjunction over all dims that are not statically known to be in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may
/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
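///
/// An illustrative candidate that satisfies the precondition below (SSA names
/// made up): a minor-identity transfer with at least one dimension that may
/// run out of bounds, e.g.
/// ```
///    %r = vector.transfer_read %A[%i, %j], %pad {in_bounds = [false, true]}
///        : memref<?x8xf32>, vector<4x8xf32>
/// ```
/// whereas a transfer with a transposing permutation map, one with all
/// dimensions in-bounds, or one nested directly under an scf.if is rejected.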
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an scf.if; this avoids
  // applying the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
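///
/// For example (an illustrative pair of types, not taken from a test), rule 2
/// maps `aT = memref<4x8xf32>` and `bT = memref<4x16xf32>` to
/// `memref<4x?xf32, strided<[?, 1]>>`: the matching leading size (4) and
/// trailing stride (1) are kept, while the trailing size and the leading
/// stride become dynamic.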
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}

/// Casts the given memref to a compatible memref type. If the source memref
/// has a different address space than the target type, a
/// `memref.memory_space_cast` is first inserted, followed by a `memref.cast`.
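/// For example (illustrative types), casting a buffer in address space 3 to
/// `memref<?x?xf32>` produces roughly:
/// ```
///    %0 = memref.memory_space_cast %m : memref<4x8xf32, 3> to memref<4x8xf32>
///    %1 = memref.cast %0 : memref<4x8xf32> to memref<?x?xf32>
/// ```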
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.getSource()` @ `xferOp.getIndices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
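// For a 2-D transfer_read of vector<4x8xf32> at %A[%i, %j] (an illustrative
// sketch; SSA names are made up), the produced source/destination pair is
// roughly:
//
//   %sz0 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%dimA0, %i, %c4)
//   %sz1 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%dimA1, %j, %c8)
//   %src = memref.subview %A[%i, %j] [%sz0, %sz1] [1, 1]
//   %dst = memref.subview %alloc[0, 0] [%sz0, %sz1] [1, 1]
//
// i.e., each size is min(dimension size - start index, vector size), clamped
// by the staging buffer's shape.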
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = b.create<memref::DimOp>(xferOp.getLoc(),
                                              xferOp.getSource(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<affine::AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.getSource(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.getSource() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res =
                castToCompatibleMemRefType(b, memref, compatibleMemRefType);
            scf::ValueVector viewAndIndices{res};
            llvm::append_range(viewAndIndices, xferOp.getIndices());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange());
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
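//
// For example (an illustrative nesting, not from a test), given
//   func.func @f() {
//     scf.for ... {
//       vector.transfer_read ...
//     }
//   }
// the returned scope for the transfer op is the `func.func` rather than the
// enclosing loop, so the staging alloca is hoisted out of the loop.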
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting alloca's in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, affine::AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block
/// to further copy a partial buffer into the final result in the slow path
/// case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2]
///        {in_bounds = [true ... true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope to
  // ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top-of-the-function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
        filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}
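
// Typical usage (a sketch under assumed driver code, not part of this file):
// populate the pattern set inside a pass and apply it greedily; `funcOp` and
// the chosen split mode below are placeholders.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::VectorTransformsOptions options;
//   options.vectorTransferSplit = vector::VectorTransferSplit::VectorTransfer;
//   vector::populateVectorTransferFullPartialPatterns(patterns, options);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();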