MLIR  20.0.0git
VectorTransferSplitRewritePatterns.cpp
Go to the documentation of this file.
1 //===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent patterns to rewrite a vector.transfer
10 // op into a fully in-bounds part and a partial part.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <optional>
15 #include <type_traits>
16 
23 
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
28 
29 #include "llvm/ADT/DenseSet.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
35 
36 #define DEBUG_TYPE "vector-transfer-split"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 /// Build the condition to ensure that a particular VectorTransferOpInterface
42 /// is in-bounds.
44  VectorTransferOpInterface xferOp) {
45  assert(xferOp.getPermutationMap().isMinorIdentity() &&
46  "Expected minor identity map");
47  Value inBoundsCond;
48  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
49  // Zip over the resulting vector shape and memref indices.
50  // If the dimension is known to be in-bounds, it does not participate in
51  // the construction of `inBoundsCond`.
52  if (xferOp.isDimInBounds(resultIdx))
53  return;
54  // Fold or create the check that `index + vector_size` <= `memref_size`.
55  Location loc = xferOp.getLoc();
56  int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
58  b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
59  {xferOp.getIndices()[indicesIdx]});
60  OpFoldResult dimSz =
61  memref::getMixedSize(b, loc, xferOp.getSource(), indicesIdx);
62  auto maybeCstSum = getConstantIntValue(sum);
63  auto maybeCstDimSz = getConstantIntValue(dimSz);
64  if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
65  return;
66  Value cond =
67  b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
69  getValueOrCreateConstantIndexOp(b, loc, dimSz));
70  // Conjunction over all dims for which we are in-bounds.
71  if (inBoundsCond)
72  inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
73  else
74  inBoundsCond = cond;
75  });
76  return inBoundsCond;
77 }
78 
79 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
80 /// masking) fast path and a slow path.
81 /// If `ifOp` is not null and the result is `success, the `ifOp` points to the
82 /// newly created conditional upon function return.
83 /// To accommodate for the fact that the original vector.transfer indexing may
84 /// be arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
85 /// scf.if op returns a view and values of type index.
86 /// At this time, only vector.transfer_read case is implemented.
87 ///
88 /// Example (a 2-D vector.transfer_read):
89 /// ```
90 /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
91 /// ```
92 /// is transformed into:
93 /// ```
94 /// %1:3 = scf.if (%inBounds) {
95 /// // fast path, direct cast
96 /// memref.cast %A: memref<A...> to compatibleMemRefType
97 /// scf.yield %view : compatibleMemRefType, index, index
98 /// } else {
99 /// // slow path, not in-bounds vector.transfer or linalg.copy.
100 /// memref.cast %alloc: memref<B...> to compatibleMemRefType
101 /// scf.yield %4 : compatibleMemRefType, index, index
102 // }
103 /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
104 /// ```
105 /// where `alloc` is a top of the function alloca'ed buffer of one vector.
106 ///
107 /// Preconditions:
108 /// 1. `xferOp.getPermutationMap()` must be a minor identity map
109 /// 2. the rank of the `xferOp.memref()` and the rank of the
110 /// `xferOp.getVector()` must be equal. This will be relaxed in the future
111 /// but requires rank-reducing subviews.
112 static LogicalResult
113 splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
114  // TODO: support 0-d corner case.
115  if (xferOp.getTransferRank() == 0)
116  return failure();
117 
118  // TODO: expand support to these 2 cases.
119  if (!xferOp.getPermutationMap().isMinorIdentity())
120  return failure();
121  // Must have some out-of-bounds dimension to be a candidate for splitting.
122  if (!xferOp.hasOutOfBoundsDim())
123  return failure();
124  // Don't split transfer operations directly under IfOp, this avoids applying
125  // the pattern recursively.
126  // TODO: improve the filtering condition to make it more applicable.
127  if (isa<scf::IfOp>(xferOp->getParentOp()))
128  return failure();
129  return success();
130 }
131 
132 /// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
133 /// be cast. If the MemRefTypes don't have the same rank or are not strided,
134 /// return null; otherwise:
135 /// 1. if `aT` and `bT` are cast-compatible, return `aT`.
136 /// 2. else return a new MemRefType obtained by iterating over the shape and
137 /// strides and:
138 /// a. keeping the ones that are static and equal across `aT` and `bT`.
139 /// b. using a dynamic shape and/or stride for the dimensions that don't
140 /// agree.
141 static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
142  if (memref::CastOp::areCastCompatible(aT, bT))
143  return aT;
144  if (aT.getRank() != bT.getRank())
145  return MemRefType();
146  int64_t aOffset, bOffset;
147  SmallVector<int64_t, 4> aStrides, bStrides;
148  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
149  failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
150  aStrides.size() != bStrides.size())
151  return MemRefType();
152 
153  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
154  int64_t resOffset;
155  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
156  resStrides(bT.getRank(), 0);
157  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
158  resShape[idx] =
159  (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
160  resStrides[idx] =
161  (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
162  }
163  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
164  return MemRefType::get(
165  resShape, aT.getElementType(),
166  StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
167 }
168 
169 /// Casts the given memref to a compatible memref type. If the source memref has
170 /// a different address space than the target type, a `memref.memory_space_cast`
171 /// is first inserted, followed by a `memref.cast`.
173  MemRefType compatibleMemRefType) {
174  MemRefType sourceType = cast<MemRefType>(memref.getType());
175  Value res = memref;
176  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
177  sourceType = MemRefType::get(
178  sourceType.getShape(), sourceType.getElementType(),
179  sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
180  res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
181  }
182  if (sourceType == compatibleMemRefType)
183  return res;
184  return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
185 }
186 
187 /// Operates under a scoped context to build the intersection between the
188 /// view `xferOp.getSource()` @ `xferOp.getIndices()` and the view `alloc`.
189 // TODO: view intersection/union/differences should be a proper std op.
190 static std::pair<Value, Value>
191 createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
192  Value alloc) {
193  Location loc = xferOp.getLoc();
194  int64_t memrefRank = xferOp.getShapedType().getRank();
195  // TODO: relax this precondition, will require rank-reducing subviews.
196  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
197  "Expected memref rank to match the alloc rank");
198  ValueRange leadingIndices =
199  xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
201  sizes.append(leadingIndices.begin(), leadingIndices.end());
202  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
203  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
204  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
205  Value dimMemRef = b.create<memref::DimOp>(xferOp.getLoc(),
206  xferOp.getSource(), indicesIdx);
207  Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
208  Value index = xferOp.getIndices()[indicesIdx];
209  AffineExpr i, j, k;
210  bindDims(xferOp.getContext(), i, j, k);
212  AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
213  // affine_min(%dimMemRef - %index, %dimAlloc)
214  Value affineMin = b.create<affine::AffineMinOp>(
215  loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
216  sizes.push_back(affineMin);
217  });
218 
219  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
220  xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
221  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
222  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
223  auto copySrc = b.create<memref::SubViewOp>(
224  loc, isaWrite ? alloc : xferOp.getSource(), srcIndices, sizes, strides);
225  auto copyDest = b.create<memref::SubViewOp>(
226  loc, isaWrite ? xferOp.getSource() : alloc, destIndices, sizes, strides);
227  return std::make_pair(copySrc, copyDest);
228 }
229 
230 /// Given an `xferOp` for which:
231 /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
232 /// 2. a memref of single vector `alloc` has been allocated.
233 /// Produce IR resembling:
234 /// ```
235 /// %1:3 = scf.if (%inBounds) {
236 /// (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
237 /// %view = memref.cast %A: memref<A...> to compatibleMemRefType
238 /// scf.yield %view, ... : compatibleMemRefType, index, index
239 /// } else {
240 /// %2 = linalg.fill(%pad, %alloc)
241 /// %3 = subview %view [...][...][...]
242 /// %4 = subview %alloc [0, 0] [...] [...]
243 /// linalg.copy(%3, %4)
244 /// %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
245 /// scf.yield %5, ... : compatibleMemRefType, index, index
246 /// }
247 /// ```
248 /// Return the produced scf::IfOp.
249 static scf::IfOp
250 createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
251  TypeRange returnTypes, Value inBoundsCond,
252  MemRefType compatibleMemRefType, Value alloc) {
253  Location loc = xferOp.getLoc();
254  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
255  Value memref = xferOp.getSource();
256  return b.create<scf::IfOp>(
257  loc, inBoundsCond,
258  [&](OpBuilder &b, Location loc) {
259  Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
260  scf::ValueVector viewAndIndices{res};
261  viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
262  xferOp.getIndices().end());
263  b.create<scf::YieldOp>(loc, viewAndIndices);
264  },
265  [&](OpBuilder &b, Location loc) {
266  b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
267  ValueRange{alloc});
268  // Take partial subview of memref which guarantees no dimension
269  // overflows.
270  IRRewriter rewriter(b);
271  std::pair<Value, Value> copyArgs = createSubViewIntersection(
272  rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
273  alloc);
274  b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
275  Value casted =
276  castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
277  scf::ValueVector viewAndIndices{casted};
278  viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
279  zero);
280  b.create<scf::YieldOp>(loc, viewAndIndices);
281  });
282 }
283 
284 /// Given an `xferOp` for which:
285 /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
286 /// 2. a memref of single vector `alloc` has been allocated.
287 /// Produce IR resembling:
288 /// ```
289 /// %1:3 = scf.if (%inBounds) {
290 /// (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
291 /// memref.cast %A: memref<A...> to compatibleMemRefType
292 /// scf.yield %view, ... : compatibleMemRefType, index, index
293 /// } else {
294 /// %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
295 /// %3 = vector.type_cast %extra_alloc :
296 /// memref<...> to memref<vector<...>>
297 /// store %2, %3[] : memref<vector<...>>
298 /// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
299 /// scf.yield %4, ... : compatibleMemRefType, index, index
300 /// }
301 /// ```
302 /// Return the produced scf::IfOp.
304  RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
305  Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
306  Location loc = xferOp.getLoc();
307  scf::IfOp fullPartialIfOp;
308  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
309  Value memref = xferOp.getSource();
310  return b.create<scf::IfOp>(
311  loc, inBoundsCond,
312  [&](OpBuilder &b, Location loc) {
313  Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
314  scf::ValueVector viewAndIndices{res};
315  viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
316  xferOp.getIndices().end());
317  b.create<scf::YieldOp>(loc, viewAndIndices);
318  },
319  [&](OpBuilder &b, Location loc) {
320  Operation *newXfer = b.clone(*xferOp.getOperation());
321  Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
322  b.create<memref::StoreOp>(
323  loc, vector,
324  b.create<vector::TypeCastOp>(
325  loc, MemRefType::get({}, vector.getType()), alloc));
326 
327  Value casted =
328  castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
329  scf::ValueVector viewAndIndices{casted};
330  viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
331  zero);
332  b.create<scf::YieldOp>(loc, viewAndIndices);
333  });
334 }
335 
336 /// Given an `xferOp` for which:
337 /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
338 /// 2. a memref of single vector `alloc` has been allocated.
339 /// Produce IR resembling:
340 /// ```
341 /// %1:3 = scf.if (%inBounds) {
342 /// memref.cast %A: memref<A...> to compatibleMemRefType
343 /// scf.yield %view, ... : compatibleMemRefType, index, index
344 /// } else {
345 /// %3 = vector.type_cast %extra_alloc :
346 /// memref<...> to memref<vector<...>>
347 /// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
348 /// scf.yield %4, ... : compatibleMemRefType, index, index
349 /// }
350 /// ```
351 static ValueRange
352 getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
353  TypeRange returnTypes, Value inBoundsCond,
354  MemRefType compatibleMemRefType, Value alloc) {
355  Location loc = xferOp.getLoc();
356  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
357  Value memref = xferOp.getSource();
358  return b
359  .create<scf::IfOp>(
360  loc, inBoundsCond,
361  [&](OpBuilder &b, Location loc) {
362  Value res =
363  castToCompatibleMemRefType(b, memref, compatibleMemRefType);
364  scf::ValueVector viewAndIndices{res};
365  viewAndIndices.insert(viewAndIndices.end(),
366  xferOp.getIndices().begin(),
367  xferOp.getIndices().end());
368  b.create<scf::YieldOp>(loc, viewAndIndices);
369  },
370  [&](OpBuilder &b, Location loc) {
371  Value casted =
372  castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
373  scf::ValueVector viewAndIndices{casted};
374  viewAndIndices.insert(viewAndIndices.end(),
375  xferOp.getTransferRank(), zero);
376  b.create<scf::YieldOp>(loc, viewAndIndices);
377  })
378  ->getResults();
379 }
380 
381 /// Given an `xferOp` for which:
382 /// 1. `inBoundsCond` has been computed.
383 /// 2. a memref of single vector `alloc` has been allocated.
384 /// 3. it originally wrote to %view
385 /// Produce IR resembling:
386 /// ```
387 /// %notInBounds = arith.xori %inBounds, %true
388 /// scf.if (%notInBounds) {
389 /// %3 = subview %alloc [...][...][...]
390 /// %4 = subview %view [0, 0][...][...]
391 /// linalg.copy(%3, %4)
392 /// }
393 /// ```
395  vector::TransferWriteOp xferOp,
396  Value inBoundsCond, Value alloc) {
397  Location loc = xferOp.getLoc();
398  auto notInBounds = b.create<arith::XOrIOp>(
399  loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
400  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
401  IRRewriter rewriter(b);
402  std::pair<Value, Value> copyArgs = createSubViewIntersection(
403  rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
404  alloc);
405  b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
406  b.create<scf::YieldOp>(loc, ValueRange{});
407  });
408 }
409 
410 /// Given an `xferOp` for which:
411 /// 1. `inBoundsCond` has been computed.
412 /// 2. a memref of single vector `alloc` has been allocated.
413 /// 3. it originally wrote to %view
414 /// Produce IR resembling:
415 /// ```
416 /// %notInBounds = arith.xori %inBounds, %true
417 /// scf.if (%notInBounds) {
418 /// %2 = load %alloc : memref<vector<...>>
419 /// vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
420 /// }
421 /// ```
423  vector::TransferWriteOp xferOp,
424  Value inBoundsCond,
425  Value alloc) {
426  Location loc = xferOp.getLoc();
427  auto notInBounds = b.create<arith::XOrIOp>(
428  loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
429  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
430  IRMapping mapping;
431  Value load = b.create<memref::LoadOp>(
432  loc,
433  b.create<vector::TypeCastOp>(
434  loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
435  ValueRange());
436  mapping.map(xferOp.getVector(), load);
437  b.clone(*xferOp.getOperation(), mapping);
438  b.create<scf::YieldOp>(loc, ValueRange{});
439  });
440 }
441 
442 // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
444  // Find the closest surrounding allocation scope that is not a known looping
445  // construct (putting alloca's in loops doesn't always lower to deallocation
446  // until the end of the loop).
447  Operation *scope = nullptr;
448  for (Operation *parent = op->getParentOp(); parent != nullptr;
449  parent = parent->getParentOp()) {
450  if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
451  scope = parent;
452  if (!isa<scf::ForOp, affine::AffineForOp>(parent))
453  break;
454  }
455  assert(scope && "Expected op to be inside automatic allocation scope");
456  return scope;
457 }
458 
459 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
460 /// masking) fastpath and a slowpath.
461 ///
462 /// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
464 /// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
466 /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
467 /// scf.if op returns a view and values of type index.
468 ///
469 /// Example (a 2-D vector.transfer_read):
470 /// ```
471 /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
472 /// ```
473 /// is transformed into:
474 /// ```
475 /// %1:3 = scf.if (%inBounds) {
476 /// // fastpath, direct cast
477 /// memref.cast %A: memref<A...> to compatibleMemRefType
478 /// scf.yield %view : compatibleMemRefType, index, index
479 /// } else {
480 /// // slowpath, not in-bounds vector.transfer or linalg.copy.
481 /// memref.cast %alloc: memref<B...> to compatibleMemRefType
482 /// scf.yield %4 : compatibleMemRefType, index, index
/// }
484 /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
485 /// ```
486 /// where `alloc` is a top of the function alloca'ed buffer of one vector.
487 ///
488 /// For vector.transfer_write:
489 /// There are 2 conditional blocks. First a block to decide which memref and
490 /// indices to use for an unmasked, inbounds write. Then a conditional block to
491 /// further copy a partial buffer into the final result in the slow path case.
492 ///
493 /// Example (a 2-D vector.transfer_write):
494 /// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
496 /// ```
497 /// is transformed into:
498 /// ```
499 /// %1:3 = scf.if (%inBounds) {
500 /// memref.cast %A: memref<A...> to compatibleMemRefType
501 /// scf.yield %view : compatibleMemRefType, index, index
502 /// } else {
503 /// memref.cast %alloc: memref<B...> to compatibleMemRefType
504 /// scf.yield %4 : compatibleMemRefType, index, index
505 /// }
506 /// %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
507 /// true]}
508 /// scf.if (%notInBounds) {
509 /// // slowpath: not in-bounds vector.transfer or linalg.copy.
510 /// }
511 /// ```
512 /// where `alloc` is a top of the function alloca'ed buffer of one vector.
513 ///
514 /// Preconditions:
515 /// 1. `xferOp.getPermutationMap()` must be a minor identity map
516 /// 2. the rank of the `xferOp.getSource()` and the rank of the
517 /// `xferOp.getVector()` must be equal. This will be relaxed in the future
518 /// but requires rank-reducing subviews.
520  RewriterBase &b, VectorTransferOpInterface xferOp,
521  VectorTransformsOptions options, scf::IfOp *ifOp) {
522  if (options.vectorTransferSplit == VectorTransferSplit::None)
523  return failure();
524 
525  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
526  auto inBoundsAttr = b.getBoolArrayAttr(bools);
527  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
528  b.modifyOpInPlace(xferOp, [&]() {
529  xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
530  });
531  return success();
532  }
533 
534  // Assert preconditions. Additionally, keep the variables in an inner scope to
535  // ensure they aren't used in the wrong scopes further down.
536  {
537  assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
538  "Expected splitFullAndPartialTransferPrecondition to hold");
539 
540  auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
541  auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
542 
543  if (!(xferReadOp || xferWriteOp))
544  return failure();
545  if (xferWriteOp && xferWriteOp.getMask())
546  return failure();
547  if (xferReadOp && xferReadOp.getMask())
548  return failure();
549  }
550 
552  b.setInsertionPoint(xferOp);
553  Value inBoundsCond = createInBoundsCond(
554  b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
555  if (!inBoundsCond)
556  return failure();
557 
558  // Top of the function `alloc` for transient storage.
559  Value alloc;
560  {
562  Operation *scope = getAutomaticAllocationScope(xferOp);
563  assert(scope->getNumRegions() == 1 &&
564  "AutomaticAllocationScope with >1 regions");
565  b.setInsertionPointToStart(&scope->getRegion(0).front());
566  auto shape = xferOp.getVectorType().getShape();
567  Type elementType = xferOp.getVectorType().getElementType();
568  alloc = b.create<memref::AllocaOp>(scope->getLoc(),
569  MemRefType::get(shape, elementType),
570  ValueRange{}, b.getI64IntegerAttr(32));
571  }
572 
573  MemRefType compatibleMemRefType =
574  getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
575  cast<MemRefType>(alloc.getType()));
576  if (!compatibleMemRefType)
577  return failure();
578 
579  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
580  b.getIndexType());
581  returnTypes[0] = compatibleMemRefType;
582 
583  if (auto xferReadOp =
584  dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
585  // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
586  scf::IfOp fullPartialIfOp =
587  options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
588  ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
589  inBoundsCond,
590  compatibleMemRefType, alloc)
591  : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
592  inBoundsCond, compatibleMemRefType,
593  alloc);
594  if (ifOp)
595  *ifOp = fullPartialIfOp;
596 
597  // Set existing read op to in-bounds, it always reads from a full buffer.
598  for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
599  xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
600 
601  b.modifyOpInPlace(xferOp, [&]() {
602  xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
603  });
604 
605  return success();
606  }
607 
608  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
609 
610  // Decide which location to write the entire vector to.
611  auto memrefAndIndices = getLocationToWriteFullVec(
612  b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
613 
614  // Do an in bounds write to either the output or the extra allocated buffer.
615  // The operation is cloned to prevent deleting information needed for the
616  // later IR creation.
617  IRMapping mapping;
618  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
619  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
620  auto *clone = b.clone(*xferWriteOp, mapping);
621  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
622 
623  // Create a potential copy from the allocated buffer to the final output in
624  // the slow path case.
625  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
626  createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
627  else
628  createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
629 
630  b.eraseOp(xferOp);
631 
632  return success();
633 }
634 
635 namespace {
636 /// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
637 /// may take an extra filter to perform selection at a finer granularity.
638 struct VectorTransferFullPartialRewriter : public RewritePattern {
639  using FilterConstraintType =
640  std::function<LogicalResult(VectorTransferOpInterface op)>;
641 
642  explicit VectorTransferFullPartialRewriter(
643  MLIRContext *context,
645  FilterConstraintType filter =
646  [](VectorTransferOpInterface op) { return success(); },
647  PatternBenefit benefit = 1)
648  : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
649  filter(std::move(filter)) {}
650 
651  /// Performs the rewrite.
652  LogicalResult matchAndRewrite(Operation *op,
653  PatternRewriter &rewriter) const override;
654 
655 private:
657  FilterConstraintType filter;
658 };
659 
660 } // namespace
661 
662 LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
663  Operation *op, PatternRewriter &rewriter) const {
664  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
665  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
666  failed(filter(xferOp)))
667  return failure();
668  return splitFullAndPartialTransfer(rewriter, xferOp, options);
669 }
670 
673  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
674  options);
675 }
@ None
static llvm::ManagedStatic< PassManagerOptions > options
static scf::IfOp createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)
Given an xferOp for which:
static void createFullPartialVectorTransferWrite(RewriterBase &b, vector::TransferWriteOp xferOp, Value inBoundsCond, Value alloc)
Given an xferOp for which:
static ValueRange getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)
Given an xferOp for which:
static scf::IfOp createFullPartialVectorTransferRead(RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)
Given an xferOp for which:
static LogicalResult splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp)
Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds masking) fast path and a ...
static std::pair< Value, Value > createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp, Value alloc)
Operates under a scoped context to build the intersection between the view xferOp....
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT)
Given two MemRefTypes aT and bT, return a MemRefType to which both can be cast.
static Value createInBoundsCond(RewriterBase &b, VectorTransferOpInterface xferOp)
Build the condition to ensure that a particular VectorTransferOpInterface is in-bounds.
static Operation * getAutomaticAllocationScope(Operation *op)
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref, MemRefType compatibleMemRefType)
Casts the given memref to a compatible memref type.
Base type for affine expression.
Definition: AffineExpr.h:68
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:412
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:95
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:310
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
A trait of region holding operations that define a new scope for automatic allocations,...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1193
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition: MemRefOps.cpp:67
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
LogicalResult splitFullAndPartialTransfer(RewriterBase &b, VectorTransferOpInterface xferOp, VectorTransformsOptions options=VectorTransformsOptions(), scf::IfOp *ifOp=nullptr)
Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds masking) fastpath and a s...
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:348
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Structure to control the behavior of vector transform patterns.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.