MLIR  19.0.0git
VectorTransferSplitRewritePatterns.cpp
Go to the documentation of this file.
1 //===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent patterns to rewrite a vector.transfer
10 // op into a fully in-bounds part and a partial part.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <optional>
15 #include <type_traits>
16 
23 
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
28 
29 #include "llvm/ADT/DenseSet.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
35 
36 #define DEBUG_TYPE "vector-transfer-split"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 /// Build the condition to ensure that a particular VectorTransferOpInterface
42 /// is in-bounds.
/// Dimensions already marked in-bounds are skipped. For every other transfer
/// dimension the check `index + vector_size <= memref_dim_size` is built; it
/// is dropped when both sides fold to constants and the inequality holds,
/// otherwise an `arith.cmpi sle` is emitted. Emitted checks are combined with
/// `arith.andi`. Returns a null Value when no runtime check is needed, in
/// which case splitFullAndPartialTransfer returns failure.
/// Requires a minor-identity permutation map (asserted).
44  VectorTransferOpInterface xferOp) {
45  assert(xferOp.getPermutationMap().isMinorIdentity() &&
46  "Expected minor identity map");
// Stays null until the first dimension that needs a runtime check.
47  Value inBoundsCond;
48  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
49  // Zip over the resulting vector shape and memref indices.
50  // If the dimension is known to be in-bounds, it does not participate in
51  // the construction of `inBoundsCond`.
52  if (xferOp.isDimInBounds(resultIdx))
53  return;
54  // Fold or create the check that `index + vector_size` <= `memref_size`.
55  Location loc = xferOp.getLoc();
56  int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
58  b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
59  {xferOp.getIndices()[indicesIdx]});
60  OpFoldResult dimSz =
61  memref::getMixedSize(b, loc, xferOp.getSource(), indicesIdx);
// When both `sum` (index + vector_size) and the dimension size are constants
// and the check provably holds, this dimension needs no runtime check.
62  auto maybeCstSum = getConstantIntValue(sum);
63  auto maybeCstDimSz = getConstantIntValue(dimSz);
64  if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
65  return;
// Signed `<=` comparison; constant operands are materialized as needed.
66  Value cond =
67  b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
69  getValueOrCreateConstantIndexOp(b, loc, dimSz));
70  // Conjunction over all dims that need a runtime in-bounds check.
71  if (inBoundsCond)
72  inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
73  else
74  inBoundsCond = cond;
75  });
76  return inBoundsCond;
77 }
78 
79 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
80 /// masking) fast path and a slow path.
81 /// If `ifOp` is not null and the result is `success, the `ifOp` points to the
82 /// newly created conditional upon function return.
83 /// To accommodate for the fact that the original vector.transfer indexing may
84 /// be arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
85 /// scf.if op returns a view and values of type index.
86 /// At this time, only vector.transfer_read case is implemented.
87 ///
88 /// Example (a 2-D vector.transfer_read):
89 /// ```
90 /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
91 /// ```
92 /// is transformed into:
93 /// ```
94 /// %1:3 = scf.if (%inBounds) {
95 /// // fast path, direct cast
96 /// memref.cast %A: memref<A...> to compatibleMemRefType
97 /// scf.yield %view : compatibleMemRefType, index, index
98 /// } else {
99 /// // slow path, not in-bounds vector.transfer or linalg.copy.
100 /// memref.cast %alloc: memref<B...> to compatibleMemRefType
101 /// scf.yield %4 : compatibleMemRefType, index, index
102 // }
103 /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
104 /// ```
105 /// where `alloc` is a top of the function alloca'ed buffer of one vector.
106 ///
107 /// Preconditions:
108 /// 1. `xferOp.getPermutationMap()` must be a minor identity map
109 /// 2. the rank of the `xferOp.memref()` and the rank of the
110 /// `xferOp.getVector()` must be equal. This will be relaxed in the future
111 /// but requires rank-reducing subviews.
112 static LogicalResult
113 splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
114  // TODO: support 0-d corner case.
115  if (xferOp.getTransferRank() == 0)
116  return failure();
117 
118  // TODO: expand support to these 2 cases.
119  if (!xferOp.getPermutationMap().isMinorIdentity())
120  return failure();
121  // Must have some out-of-bounds dimension to be a candidate for splitting.
122  if (!xferOp.hasOutOfBoundsDim())
123  return failure();
124  // Don't split transfer operations directly under IfOp, this avoids applying
125  // the pattern recursively.
126  // TODO: improve the filtering condition to make it more applicable.
127  if (isa<scf::IfOp>(xferOp->getParentOp()))
128  return failure();
129  return success();
130 }
131 
132 /// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
133 /// be cast. If the MemRefTypes don't have the same rank or are not strided,
134 /// return null; otherwise:
135 /// 1. if `aT` and `bT` are cast-compatible, return `aT`.
136 /// 2. else return a new MemRefType obtained by iterating over the shape and
137 /// strides and:
138 /// a. keeping the ones that are static and equal across `aT` and `bT`.
139 /// b. using a dynamic shape and/or stride for the dimensions that don't
140 /// agree.
141 static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
142  if (memref::CastOp::areCastCompatible(aT, bT))
143  return aT;
144  if (aT.getRank() != bT.getRank())
145  return MemRefType();
146  int64_t aOffset, bOffset;
147  SmallVector<int64_t, 4> aStrides, bStrides;
148  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
149  failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
150  aStrides.size() != bStrides.size())
151  return MemRefType();
152 
153  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
154  int64_t resOffset;
155  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
156  resStrides(bT.getRank(), 0);
157  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
158  resShape[idx] =
159  (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
160  resStrides[idx] =
161  (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
162  }
163  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
164  return MemRefType::get(
165  resShape, aT.getElementType(),
166  StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
167 }
168 
169 /// Casts the given memref to a compatible memref type. If the source memref has
170 /// a different address space than the target type, a `memref.memory_space_cast`
171 /// is first inserted, followed by a `memref.cast`.
173  MemRefType compatibleMemRefType) {
174  MemRefType sourceType = cast<MemRefType>(memref.getType());
175  Value res = memref;
176  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
177  sourceType = MemRefType::get(
178  sourceType.getShape(), sourceType.getElementType(),
179  sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
180  res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
181  }
182  if (sourceType == compatibleMemRefType)
183  return res;
184  return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
185 }
186 
187 /// Operates under a scoped context to build the intersection between the
188 /// view `xferOp.getSource()` @ `xferOp.getIndices()` and the view `alloc`.
189 // TODO: view intersection/union/differences should be a proper std op.
190 static std::pair<Value, Value>
191 createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
192  Value alloc) {
193  Location loc = xferOp.getLoc();
194  int64_t memrefRank = xferOp.getShapedType().getRank();
195  // TODO: relax this precondition, will require rank-reducing subviews.
196  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
197  "Expected memref rank to match the alloc rank");
198  ValueRange leadingIndices =
199  xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
201  sizes.append(leadingIndices.begin(), leadingIndices.end());
202  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
203  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
204  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
205  Value dimMemRef = b.create<memref::DimOp>(xferOp.getLoc(),
206  xferOp.getSource(), indicesIdx);
207  Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
208  Value index = xferOp.getIndices()[indicesIdx];
209  AffineExpr i, j, k;
210  bindDims(xferOp.getContext(), i, j, k);
212  AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
213  // affine_min(%dimMemRef - %index, %dimAlloc)
214  Value affineMin = b.create<affine::AffineMinOp>(
215  loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
216  sizes.push_back(affineMin);
217  });
218 
219  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
220  xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
221  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
222  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
223  auto copySrc = b.create<memref::SubViewOp>(
224  loc, isaWrite ? alloc : xferOp.getSource(), srcIndices, sizes, strides);
225  auto copyDest = b.create<memref::SubViewOp>(
226  loc, isaWrite ? xferOp.getSource() : alloc, destIndices, sizes, strides);
227  return std::make_pair(copySrc, copyDest);
228 }
229 
230 /// Given an `xferOp` for which:
231 /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
232 /// 2. a memref of single vector `alloc` has been allocated.
233 /// Produce IR resembling:
234 /// ```
235 /// %1:3 = scf.if (%inBounds) {
236 /// (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
237 /// %view = memref.cast %A: memref<A...> to compatibleMemRefType
238 /// scf.yield %view, ... : compatibleMemRefType, index, index
239 /// } else {
240 /// %2 = linalg.fill(%pad, %alloc)
241 /// %3 = subview %view [...][...][...]
242 /// %4 = subview %alloc [0, 0] [...] [...]
243 /// linalg.copy(%3, %4)
244 /// %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
245 /// scf.yield %5, ... : compatibleMemRefType, index, index
246 /// }
247 /// ```
248 /// Return the produced scf::IfOp.
249 static scf::IfOp
250 createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
251  TypeRange returnTypes, Value inBoundsCond,
252  MemRefType compatibleMemRefType, Value alloc) {
253  Location loc = xferOp.getLoc();
254  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
255  Value memref = xferOp.getSource();
256  return b.create<scf::IfOp>(
257  loc, inBoundsCond,
258  [&](OpBuilder &b, Location loc) {
259  Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
260  scf::ValueVector viewAndIndices{res};
261  viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
262  xferOp.getIndices().end());
263  b.create<scf::YieldOp>(loc, viewAndIndices);
264  },
265  [&](OpBuilder &b, Location loc) {
266  b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
267  ValueRange{alloc});
268  // Take partial subview of memref which guarantees no dimension
269  // overflows.
270  IRRewriter rewriter(b);
271  std::pair<Value, Value> copyArgs = createSubViewIntersection(
272  rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
273  alloc);
274  b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
275  Value casted =
276  castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
277  scf::ValueVector viewAndIndices{casted};
278  viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
279  zero);
280  b.create<scf::YieldOp>(loc, viewAndIndices);
281  });
282 }
283 
284 /// Given an `xferOp` for which:
285 /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
286 /// 2. a memref of single vector `alloc` has been allocated.
287 /// Produce IR resembling:
288 /// ```
289 /// %1:3 = scf.if (%inBounds) {
290 /// (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
291 /// memref.cast %A: memref<A...> to compatibleMemRefType
292 /// scf.yield %view, ... : compatibleMemRefType, index, index
293 /// } else {
294 /// %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
295 /// %3 = vector.type_cast %extra_alloc :
296 /// memref<...> to memref<vector<...>>
297 /// store %2, %3[] : memref<vector<...>>
298 /// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
299 /// scf.yield %4, ... : compatibleMemRefType, index, index
300 /// }
301 /// ```
302 /// Return the produced scf::IfOp.
304  RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
305  Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
306  Location loc = xferOp.getLoc();
307  scf::IfOp fullPartialIfOp;
308  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
309  Value memref = xferOp.getSource();
310  return b.create<scf::IfOp>(
311  loc, inBoundsCond,
312  [&](OpBuilder &b, Location loc) {
313  Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
314  scf::ValueVector viewAndIndices{res};
315  viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
316  xferOp.getIndices().end());
317  b.create<scf::YieldOp>(loc, viewAndIndices);
318  },
319  [&](OpBuilder &b, Location loc) {
320  Operation *newXfer = b.clone(*xferOp.getOperation());
321  Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
322  b.create<memref::StoreOp>(
323  loc, vector,
324  b.create<vector::TypeCastOp>(
325  loc, MemRefType::get({}, vector.getType()), alloc));
326 
327  Value casted =
328  castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
329  scf::ValueVector viewAndIndices{casted};
330  viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
331  zero);
332  b.create<scf::YieldOp>(loc, viewAndIndices);
333  });
334 }
335 
336 /// Given an `xferOp` for which:
337 /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
338 /// 2. a memref of single vector `alloc` has been allocated.
339 /// Produce IR resembling:
340 /// ```
341 /// %1:3 = scf.if (%inBounds) {
342 /// memref.cast %A: memref<A...> to compatibleMemRefType
343 /// scf.yield %view, ... : compatibleMemRefType, index, index
344 /// } else {
345 /// %3 = vector.type_cast %extra_alloc :
346 /// memref<...> to memref<vector<...>>
347 /// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
348 /// scf.yield %4, ... : compatibleMemRefType, index, index
349 /// }
350 /// ```
351 static ValueRange
352 getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
353  TypeRange returnTypes, Value inBoundsCond,
354  MemRefType compatibleMemRefType, Value alloc) {
355  Location loc = xferOp.getLoc();
356  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
357  Value memref = xferOp.getSource();
358  return b
359  .create<scf::IfOp>(
360  loc, inBoundsCond,
361  [&](OpBuilder &b, Location loc) {
362  Value res =
363  castToCompatibleMemRefType(b, memref, compatibleMemRefType);
364  scf::ValueVector viewAndIndices{res};
365  viewAndIndices.insert(viewAndIndices.end(),
366  xferOp.getIndices().begin(),
367  xferOp.getIndices().end());
368  b.create<scf::YieldOp>(loc, viewAndIndices);
369  },
370  [&](OpBuilder &b, Location loc) {
371  Value casted =
372  castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
373  scf::ValueVector viewAndIndices{casted};
374  viewAndIndices.insert(viewAndIndices.end(),
375  xferOp.getTransferRank(), zero);
376  b.create<scf::YieldOp>(loc, viewAndIndices);
377  })
378  ->getResults();
379 }
380 
381 /// Given an `xferOp` for which:
382 /// 1. `inBoundsCond` has been computed.
383 /// 2. a memref of single vector `alloc` has been allocated.
384 /// 3. it originally wrote to %view
385 /// Produce IR resembling:
386 /// ```
387 /// %notInBounds = arith.xori %inBounds, %true
388 /// scf.if (%notInBounds) {
389 /// %3 = subview %alloc [...][...][...]
390 /// %4 = subview %view [0, 0][...][...]
391 /// linalg.copy(%3, %4)
392 /// }
393 /// ```
395  vector::TransferWriteOp xferOp,
396  Value inBoundsCond, Value alloc) {
397  Location loc = xferOp.getLoc();
398  auto notInBounds = b.create<arith::XOrIOp>(
399  loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
400  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
401  IRRewriter rewriter(b);
402  std::pair<Value, Value> copyArgs = createSubViewIntersection(
403  rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
404  alloc);
405  b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
406  b.create<scf::YieldOp>(loc, ValueRange{});
407  });
408 }
409 
410 /// Given an `xferOp` for which:
411 /// 1. `inBoundsCond` has been computed.
412 /// 2. a memref of single vector `alloc` has been allocated.
413 /// 3. it originally wrote to %view
414 /// Produce IR resembling:
415 /// ```
416 /// %notInBounds = arith.xori %inBounds, %true
417 /// scf.if (%notInBounds) {
418 /// %2 = load %alloc : memref<vector<...>>
419 /// vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
420 /// }
421 /// ```
423  vector::TransferWriteOp xferOp,
424  Value inBoundsCond,
425  Value alloc) {
426  Location loc = xferOp.getLoc();
427  auto notInBounds = b.create<arith::XOrIOp>(
428  loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
429  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
430  IRMapping mapping;
431  Value load = b.create<memref::LoadOp>(
432  loc,
433  b.create<vector::TypeCastOp>(
434  loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
435  ValueRange());
436  mapping.map(xferOp.getVector(), load);
437  b.clone(*xferOp.getOperation(), mapping);
438  b.create<scf::YieldOp>(loc, ValueRange{});
439  });
440 }
441 
442 // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
444  // Find the closest surrounding allocation scope that is not a known looping
445  // construct (putting alloca's in loops doesn't always lower to deallocation
446  // until the end of the loop).
447  Operation *scope = nullptr;
448  for (Operation *parent = op->getParentOp(); parent != nullptr;
449  parent = parent->getParentOp()) {
450  if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
451  scope = parent;
452  if (!isa<scf::ForOp, affine::AffineForOp>(parent))
453  break;
454  }
455  assert(scope && "Expected op to be inside automatic allocation scope");
456  return scope;
457 }
458 
459 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
460 /// masking) fastpath and a slowpath.
461 ///
462 /// For vector.transfer_read:
463 /// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
464 /// newly created conditional upon function return.
465 /// To accommodate the fact that the original vector.transfer indexing may be
466 /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
467 /// scf.if op returns a view and values of type index.
468 ///
469 /// Example (a 2-D vector.transfer_read):
470 /// ```
471 /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
472 /// ```
473 /// is transformed into:
474 /// ```
475 /// %1:3 = scf.if (%inBounds) {
476 /// // fastpath, direct cast
477 /// memref.cast %A: memref<A...> to compatibleMemRefType
478 /// scf.yield %view : compatibleMemRefType, index, index
479 /// } else {
480 /// // slowpath, not in-bounds vector.transfer or linalg.fill + memref.copy.
481 /// memref.cast %alloc: memref<B...> to compatibleMemRefType
482 /// scf.yield %4 : compatibleMemRefType, index, index
483 /// }
484 /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
485 /// ```
486 /// where `alloc` is a top of the function alloca'ed buffer of one vector.
487 ///
488 /// For vector.transfer_write:
489 /// There are 2 conditional blocks. First a block to decide which memref and
490 /// indices to use for an unmasked, inbounds write. Then a conditional block to
491 /// further copy a partial buffer into the final result in the slow path case.
492 ///
493 /// Example (a 2-D vector.transfer_write):
494 /// ```
495 /// vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
496 /// ```
497 /// is transformed into:
498 /// ```
499 /// %1:3 = scf.if (%inBounds) {
500 /// memref.cast %A: memref<A...> to compatibleMemRefType
501 /// scf.yield %view : compatibleMemRefType, index, index
502 /// } else {
503 /// memref.cast %alloc: memref<B...> to compatibleMemRefType
504 /// scf.yield %4 : compatibleMemRefType, index, index
505 /// }
506 /// %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
507 /// true]}
508 /// scf.if (%notInBounds) {
509 /// // slowpath: not in-bounds vector.transfer or memref.copy.
510 /// }
511 /// ```
512 /// where `alloc` is a top of the function alloca'ed buffer of one vector.
513 ///
514 /// Preconditions:
515 /// 1. `xferOp.getPermutationMap()` must be a minor identity map
516 /// 2. the rank of the `xferOp.getSource()` and the rank of the
517 /// `xferOp.getVector()` must be equal. This will be relaxed in the future
518 /// but requires rank-reducing subviews.
520  RewriterBase &b, VectorTransferOpInterface xferOp,
521  VectorTransformsOptions options, scf::IfOp *ifOp) {
522  if (options.vectorTransferSplit == VectorTransferSplit::None)
523  return failure();
524 
// All-true `in_bounds` attribute, applied to the rewritten transfer below.
525  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
526  auto inBoundsAttr = b.getBoolArrayAttr(bools);
// ForceInBounds only marks every dimension in-bounds; no split IR is built.
527  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
528  b.modifyOpInPlace(xferOp, [&]() {
529  xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
530  });
531  return success();
532  }
533 
534  // Assert preconditions. Additionally, keep the variables in an inner scope to
535  // ensure they aren't used in the wrong scopes further down.
536  {
538  "Expected splitFullAndPartialTransferPrecondition to hold");
539 
540  auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
541  auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
542 
// Only unmasked transfer_read / transfer_write ops are handled.
543  if (!(xferReadOp || xferWriteOp))
544  return failure();
545  if (xferWriteOp && xferWriteOp.getMask())
546  return failure();
547  if (xferReadOp && xferReadOp.getMask())
548  return failure();
549  }
550 
552  b.setInsertionPoint(xferOp);
// A null condition means no runtime check is needed: nothing to split.
553  Value inBoundsCond = createInBoundsCond(
554  b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
555  if (!inBoundsCond)
556  return failure();
557 
558  // Top of the function `alloc` for transient storage.
559  Value alloc;
560  {
562  Operation *scope = getAutomaticAllocationScope(xferOp);
563  assert(scope->getNumRegions() == 1 &&
564  "AutomaticAllocationScope with >1 regions");
565  b.setInsertionPointToStart(&scope->getRegion(0).front());
566  auto shape = xferOp.getVectorType().getShape();
567  Type elementType = xferOp.getVectorType().getElementType();
// NOTE(review): the trailing `b.getI64IntegerAttr(32)` is presumably the
// alloca alignment; confirm against memref::AllocaOp's builders.
568  alloc = b.create<memref::AllocaOp>(scope->getLoc(),
569  MemRefType::get(shape, elementType),
570  ValueRange{}, b.getI64IntegerAttr(32));
571  }
572 
573  MemRefType compatibleMemRefType =
574  getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
575  cast<MemRefType>(alloc.getType()));
576  if (!compatibleMemRefType)
577  return failure();
578 
// Result types of the view-selecting scf.if: the view, then one index per dim.
579  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
580  b.getIndexType());
581  returnTypes[0] = compatibleMemRefType;
582 
583  if (auto xferReadOp =
584  dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
585  // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
586  scf::IfOp fullPartialIfOp =
587  options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
588  ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
589  inBoundsCond,
590  compatibleMemRefType, alloc)
591  : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
592  inBoundsCond, compatibleMemRefType,
593  alloc);
594  if (ifOp)
595  *ifOp = fullPartialIfOp;
596 
597  // Set existing read op to in-bounds, it always reads from a full buffer.
// Operands 0..rank are the source followed by the indices; they are rewired to
// the scf.if results (view, indices).
// NOTE(review): these `setOperand` calls happen outside `modifyOpInPlace`, so a
// listening rewriter is not told about them; consider moving them into the
// `modifyOpInPlace` callback below.
598  for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
599  xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
600 
601  b.modifyOpInPlace(xferOp, [&]() {
602  xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
603  });
604 
605  return success();
606  }
607 
608  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
609 
610  // Decide which location to write the entire vector to.
611  auto memrefAndIndices = getLocationToWriteFullVec(
612  b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
613 
614  // Do an in bounds write to either the output or the extra allocated buffer.
615  // The operation is cloned to prevent deleting information needed for the
616  // later IR creation.
617  IRMapping mapping;
618  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
619  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
620  auto *clone = b.clone(*xferWriteOp, mapping);
621  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
622 
623  // Create a potential copy from the allocated buffer to the final output in
624  // the slow path case.
625  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
626  createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
627  else
628  createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
629 
// The clone plus the slow-path IR replace the original write.
630  b.eraseOp(xferOp);
631 
632  return success();
633 }
634 
635 namespace {
636 /// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
637 /// may take an extra filter to perform selection at a finer granularity.
/// The pattern is registered with `MatchAnyOpTypeTag`, so every op is offered
/// to it. `matchAndRewrite` only acts on ops that implement
/// `VectorTransferOpInterface` and pass both the split precondition and
/// `filter`.
638 struct VectorTransferFullPartialRewriter : public RewritePattern {
639  using FilterConstraintType =
640  std::function<LogicalResult(VectorTransferOpInterface op)>;
641 
// The default `filter` accepts every transfer op.
642  explicit VectorTransferFullPartialRewriter(
643  MLIRContext *context,
645  FilterConstraintType filter =
646  [](VectorTransferOpInterface op) { return success(); },
647  PatternBenefit benefit = 1)
648  : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
649  filter(std::move(filter)) {}
650 
651  /// Performs the rewrite.
652  LogicalResult matchAndRewrite(Operation *op,
653  PatternRewriter &rewriter) const override;
654 
655 private:
// User-supplied predicate; the split is attempted only when it succeeds.
657  FilterConstraintType filter;
658 };
659 
660 } // namespace
661 
662 LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
663  Operation *op, PatternRewriter &rewriter) const {
664  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
665  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
666  failed(filter(xferOp)))
667  return failure();
668  return splitFullAndPartialTransfer(rewriter, xferOp, options);
669 }
670 
// Register the full/partial split pattern with the caller-provided `options`
// and the pattern's default filter, which accepts every transfer op.
673  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
674  options);
675 }
@ None
static llvm::ManagedStatic< PassManagerOptions > options
static scf::IfOp createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)
Given an xferOp for which:
static void createFullPartialVectorTransferWrite(RewriterBase &b, vector::TransferWriteOp xferOp, Value inBoundsCond, Value alloc)
Given an xferOp for which:
static ValueRange getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)
Given an xferOp for which:
static scf::IfOp createFullPartialVectorTransferRead(RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)
Given an xferOp for which:
static LogicalResult splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp)
Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds masking) fast path and a ...
static std::pair< Value, Value > createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp, Value alloc)
Operates under a scoped context to build the intersection between the view xferOp....
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT)
Given two MemRefTypes aT and bT, return a MemRefType to which both can be cast.
static Value createInBoundsCond(RewriterBase &b, VectorTransferOpInterface xferOp)
Build the condition to ensure that a particular VectorTransferOpInterface is in-bounds.
static Operation * getAutomaticAllocationScope(Operation *op)
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref, MemRefType compatibleMemRefType)
Casts the given memref to a compatible memref type.
Base type for affine expression.
Definition: AffineExpr.h:69
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:296
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:379
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:371
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:277
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
A trait of region holding operations that define a new scope for automatic allocations,...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1188
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition: MemRefOps.cpp:67
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
LogicalResult splitFullAndPartialTransfer(RewriterBase &b, VectorTransferOpInterface xferOp, VectorTransformsOptions options=VectorTransformsOptions(), scf::IfOp *ifOp=nullptr)
Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds masking) fastpath and a s...
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:349
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Structure to control the behavior of vector transform patterns.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.