VectorTransferSplitRewritePatterns.cpp
//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

static Optional<int64_t> extractConstantIndex(Value v) {
  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
    return cstOp.value();
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return None;
}
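// Illustration (not in the original source): `extractConstantIndex` returns 4
// for `%c4 = arith.constant 4 : index`, or for an `affine.apply` whose map is
// the single constant `affine_map<() -> (4)>`; it returns `None` for any other
// value.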

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
}
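// Illustration (not in the original source): with `v` defined by
// `arith.constant 4 : index` and `ub` by `arith.constant 8 : index`, the
// comparison is statically true, so no op is created and a null Value is
// returned; callers treat a null result as "trivially in-bounds". Otherwise an
// `arith.cmpi sle, %v, %ub : index` is materialized.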

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    auto d0 = getAffineDimExpr(0, xferOp.getContext());
    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
    Value sum =
        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
    Value cond = createFoldedSLE(
        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
    if (!cond)
      return;
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}
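// Sketch of the generated check (illustrative, not from the original source):
// for a transfer of vector<4xf32> at index %i along a dimension of dynamic
// size, the condition for that dimension resembles:
//   %sum  = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
//   %dim  = memref.dim %A, %c0 : memref<?xf32>
//   %cond = arith.cmpi sle, %sum, %dim : index
// Per-dimension conditions are combined with `arith.andi`.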

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function `alloca`'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an scf.if op; this avoids
  // applying the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}
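// Illustration (not in the original source): a transfer such as
//   %r = vector.transfer_read %A[%i, %j], %pad {in_bounds = [true, false]}
//       : memref<?x?xf32>, vector<4x8xf32>
// satisfies the precondition: its (default) permutation map is a minor
// identity and its trailing dimension may be out of bounds.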

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
    resStrides[idx] = (aStrides[idx] == bStrides[idx])
                          ? aStrides[idx]
                          : ShapedType::kDynamicStrideOrOffset;
  }
  resOffset =
      (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
  return MemRefType::get(
      resShape, aT.getElementType(),
      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}
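// Illustration (not in the original source): memref<4x8xf32> and
// memref<4x16xf32> (identity layouts, i.e. strides [8, 1] and [16, 1]) are not
// cast-compatible, so the result keeps the parts that agree and relaxes the
// rest to dynamic values:
//   memref<4x?xf32, offset: 0, strides: [?, 1]>
// Both input types can then be memref.cast to this common type.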

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
                                                xferOp.source(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}
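// Illustration (not in the original source): for a read of vector<8xf32> from
// %A : memref<?xf32> at index %i, the intersection size for that dimension is
//   %sz = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>
//             (%dimA, %i, %dimAlloc)
// where %dimAlloc is 8 for the statically shaped staging buffer, and the
// returned pair is (subview %A[%i][%sz][1], subview %alloc[0][%sz][1]).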

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take a partial subview of the memref, which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, returnTypes, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res = memref;
            if (compatibleMemRefType != xferOp.getShapedType())
              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to `%view`.
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to `%view`.
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    BlockAndValueMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc));
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting alloca's in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}
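// Illustration (not in the original source): for a transfer nested as
//   func.func @f() {        // has the AutomaticAllocationScope trait
//     scf.for %i = ... {    // looping construct, keep walking up
//       vector.transfer_read ...
//     }
//   }
// the returned scope is the func.func, so the staging alloca is hoisted out of
// the loop rather than allocated on every iteration.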

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function `alloca`'ed buffer of one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block to
/// further copy a partial buffer into the final result in the slow path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
///                                                               true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a top-of-the-function `alloca`'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope to
  // ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  BlockAndValueMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  xferOp->erase();

  return success();
}

LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
  rewriter.cancelRootUpdate(xferOp);
  return failure();
}
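
// Usage sketch (illustrative, not part of this file): the split is typically
// applied by adding the pattern to a RewritePatternSet from a pass, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<vector::VectorTransferFullPartialRewriter>(
//       ctx, VectorTransformsOptions().setVectorTransferSplit(
//                VectorTransferSplit::VectorTransfer));
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
//
// or programmatically, by calling `splitFullAndPartialTransfer` on a
// `VectorTransferOpInterface` op with a rewriter whose insertion point is set
// appropriately.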