MLIR 23.0.0git
VectorTransferSplitRewritePatterns.cpp
Go to the documentation of this file.
1//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent patterns to rewrite a vector.transfer
10// op into a fully in-bounds part and a partial part.
11//
12//===----------------------------------------------------------------------===//
13
14#include <optional>
15
22#include "llvm/ADT/SmallVectorExtras.h"
23
27
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/SmallVectorExtras.h"
30
31#define DEBUG_TYPE "vector-transfer-split"
32
33using namespace mlir;
34using namespace mlir::vector;
35
36/// Build the condition to ensure that a particular VectorTransferOpInterface
37/// is in-bounds.
39 VectorTransferOpInterface xferOp) {
40 assert(xferOp.getPermutationMap().isMinorIdentity() &&
41 "Expected minor identity map");
42 Value inBoundsCond;
43 xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
44 // Zip over the resulting vector shape and memref indices.
45 // If the dimension is known to be in-bounds, it does not participate in
46 // the construction of `inBoundsCond`.
47 if (xferOp.isDimInBounds(resultIdx))
48 return;
49 // Fold or create the check that `index + vector_size` <= `memref_size`.
50 Location loc = xferOp.getLoc();
51 int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
53 b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
54 {xferOp.getIndices()[indicesIdx]});
55 OpFoldResult dimSz =
56 memref::getMixedSize(b, loc, xferOp.getBase(), indicesIdx);
57 auto maybeCstSum = getConstantIntValue(sum);
58 auto maybeCstDimSz = getConstantIntValue(dimSz);
59 if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
60 return;
61 Value cond =
62 arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sle,
65 // Conjunction over all dims for which we are in-bounds.
66 if (inBoundsCond)
67 inBoundsCond = arith::AndIOp::create(b, loc, inBoundsCond, cond);
68 else
69 inBoundsCond = cond;
70 });
71 return inBoundsCond;
72}
73
74/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
75/// masking) fast path and a slow path.
76/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
77/// newly created conditional upon function return.
78/// To accommodate for the fact that the original vector.transfer indexing may
79/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
80/// scf.if op returns a view and values of type index.
81/// At this time, only vector.transfer_read case is implemented.
82///
83/// Example (a 2-D vector.transfer_read):
84/// ```
85/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
86/// ```
87/// is transformed into:
88/// ```
89/// %1:3 = scf.if (%inBounds) {
90/// // fast path, direct cast
91/// memref.cast %A: memref<A...> to compatibleMemRefType
92/// scf.yield %view : compatibleMemRefType, index, index
93/// } else {
94/// // slow path, not in-bounds vector.transfer or linalg.copy.
95/// memref.cast %alloc: memref<B...> to compatibleMemRefType
96/// scf.yield %4 : compatibleMemRefType, index, index
97// }
98/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
99/// ```
100/// where `alloc` is a top of the function alloca'ed buffer of one vector.
101///
102/// Preconditions:
103/// 1. `xferOp.getPermutationMap()` must be a minor identity map
104/// 2. the rank of the `xferOp.memref()` and the rank of the
105/// `xferOp.getVector()` must be equal. This will be relaxed in the future
106/// but requires rank-reducing subviews.
107static LogicalResult
108splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
109 // TODO: support 0-d corner case.
110 if (xferOp.getTransferRank() == 0)
111 return failure();
112
113 // TODO: expand support to these 2 cases.
114 if (!xferOp.getPermutationMap().isMinorIdentity())
115 return failure();
116 // Must have some out-of-bounds dimension to be a candidate for splitting.
117 if (!xferOp.hasOutOfBoundsDim())
118 return failure();
119 // Don't split transfer operations directly under IfOp, this avoids applying
120 // the pattern recursively.
121 // TODO: improve the filtering condition to make it more applicable.
122 if (isa<scf::IfOp>(xferOp->getParentOp()))
123 return failure();
124 return success();
125}
126
127/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
128/// be cast. If the MemRefTypes don't have the same rank or are not strided,
129/// return null; otherwise:
130/// 1. if `aT` and `bT` are cast-compatible, return `aT`.
131/// 2. else return a new MemRefType obtained by iterating over the shape and
132/// strides and:
133/// a. keeping the ones that are static and equal across `aT` and `bT`.
134/// b. using a dynamic shape and/or stride for the dimensions that don't
135/// agree.
136static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
137 if (memref::CastOp::areCastCompatible(aT, bT))
138 return aT;
139 if (aT.getRank() != bT.getRank())
140 return MemRefType();
141 int64_t aOffset, bOffset;
142 SmallVector<int64_t, 4> aStrides, bStrides;
143 if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
144 failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
145 aStrides.size() != bStrides.size())
146 return MemRefType();
147
148 ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
149 int64_t resOffset;
150 SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
151 resStrides(bT.getRank(), 0);
152 for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
153 resShape[idx] =
154 (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
155 resStrides[idx] =
156 (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
157 }
158 resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
159 return MemRefType::get(
160 resShape, aT.getElementType(),
161 StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
162}
163
164/// Casts the given memref to a compatible memref type. If the source memref has
165/// a different address space than the target type, a `memref.memory_space_cast`
166/// is first inserted, followed by a `memref.cast`.
167static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
168 MemRefType compatibleMemRefType) {
169 MemRefType sourceType = cast<MemRefType>(memref.getType());
170 Value res = memref;
171 if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
172 sourceType = MemRefType::get(
173 sourceType.getShape(), sourceType.getElementType(),
174 sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
175 res =
176 memref::MemorySpaceCastOp::create(b, memref.getLoc(), sourceType, res);
177 }
178 if (sourceType == compatibleMemRefType)
179 return res;
180 return memref::CastOp::create(b, memref.getLoc(), compatibleMemRefType, res);
181}
182
183/// Operates under a scoped context to build the intersection between the
184/// view `xferOp.getbase()` @ `xferOp.getIndices()` and the view `alloc`.
185// TODO: view intersection/union/differences should be a proper std op.
186static std::pair<Value, Value>
187createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
188 Value alloc) {
189 Location loc = xferOp.getLoc();
190 int64_t memrefRank = xferOp.getShapedType().getRank();
191 // TODO: relax this precondition, will require rank-reducing subviews.
192 assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
193 "Expected memref rank to match the alloc rank");
194 ValueRange leadingIndices =
195 xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
197 sizes.append(leadingIndices.begin(), leadingIndices.end());
198 auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
199 xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
200 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
201 Value dimMemRef =
202 memref::DimOp::create(b, xferOp.getLoc(), xferOp.getBase(), indicesIdx);
203 Value dimAlloc = memref::DimOp::create(b, loc, alloc, resultIdx);
204 Value index = xferOp.getIndices()[indicesIdx];
205 AffineExpr i, j, k;
206 bindDims(xferOp.getContext(), i, j, k);
208 AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
209 // affine_min(%dimMemRef - %index, %dimAlloc)
210 Value affineMin =
211 affine::AffineMinOp::create(b, loc, index.getType(), maps[0],
212 ValueRange{dimMemRef, index, dimAlloc});
213 sizes.push_back(affineMin);
214 });
215
216 SmallVector<OpFoldResult> srcIndices = llvm::map_to_vector<4>(
217 xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; });
218 SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
219 SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
220 auto copySrc = memref::SubViewOp::create(
221 b, loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
222 auto copyDest = memref::SubViewOp::create(
223 b, loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
224 return std::make_pair(copySrc, copyDest);
225}
226
227/// Given an `xferOp` for which:
228/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
229/// 2. a memref of single vector `alloc` has been allocated.
230/// Produce IR resembling:
231/// ```
232/// %1:3 = scf.if (%inBounds) {
233/// (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
234/// %view = memref.cast %A: memref<A...> to compatibleMemRefType
235/// scf.yield %view, ... : compatibleMemRefType, index, index
236/// } else {
237/// %2 = linalg.fill(%pad, %alloc)
238/// %3 = subview %view [...][...][...]
239/// %4 = subview %alloc [0, 0] [...] [...]
240/// linalg.copy(%3, %4)
241/// %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
242/// scf.yield %5, ... : compatibleMemRefType, index, index
243/// }
244/// ```
245/// Return the produced scf::IfOp.
246static scf::IfOp
247createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
248 TypeRange returnTypes, Value inBoundsCond,
249 MemRefType compatibleMemRefType, Value alloc) {
250 Location loc = xferOp.getLoc();
251 Value zero = arith::ConstantIndexOp::create(b, loc, 0);
252 Value memref = xferOp.getBase();
253 return scf::IfOp::create(
254 b, loc, inBoundsCond,
255 [&](OpBuilder &b, Location loc) {
256 Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
257 scf::ValueVector viewAndIndices{res};
258 llvm::append_range(viewAndIndices, xferOp.getIndices());
259 scf::YieldOp::create(b, loc, viewAndIndices);
260 },
261 [&](OpBuilder &b, Location loc) {
262 linalg::FillOp::create(b, loc, ValueRange{xferOp.getPadding()},
263 ValueRange{alloc});
264 // Take partial subview of memref which guarantees no dimension
265 // overflows.
266 IRRewriter rewriter(b);
267 std::pair<Value, Value> copyArgs = createSubViewIntersection(
268 rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
269 alloc);
270 memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
271 Value casted =
272 castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
273 scf::ValueVector viewAndIndices{casted};
274 viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
275 zero);
276 scf::YieldOp::create(b, loc, viewAndIndices);
277 });
278}
279
280/// Given an `xferOp` for which:
281/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
282/// 2. a memref of single vector `alloc` has been allocated.
283/// Produce IR resembling:
284/// ```
285/// %1:3 = scf.if (%inBounds) {
286/// (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
287/// memref.cast %A: memref<A...> to compatibleMemRefType
288/// scf.yield %view, ... : compatibleMemRefType, index, index
289/// } else {
290/// %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
291/// %3 = vector.type_cast %extra_alloc :
292/// memref<...> to memref<vector<...>>
293/// store %2, %3[] : memref<vector<...>>
294/// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
295/// scf.yield %4, ... : compatibleMemRefType, index, index
296/// }
297/// ```
298/// Return the produced scf::IfOp.
299static scf::IfOp createFullPartialVectorTransferRead(
300 RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
301 Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
302 Location loc = xferOp.getLoc();
303 scf::IfOp fullPartialIfOp;
304 Value zero = arith::ConstantIndexOp::create(b, loc, 0);
305 Value memref = xferOp.getBase();
306 return scf::IfOp::create(
307 b, loc, inBoundsCond,
308 [&](OpBuilder &b, Location loc) {
309 Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
310 scf::ValueVector viewAndIndices{res};
311 llvm::append_range(viewAndIndices, xferOp.getIndices());
312 scf::YieldOp::create(b, loc, viewAndIndices);
313 },
314 [&](OpBuilder &b, Location loc) {
315 Operation *newXfer = b.clone(*xferOp.getOperation());
316 Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
317 memref::StoreOp::create(
318 b, loc, vector,
319 vector::TypeCastOp::create(
320 b, loc, MemRefType::get({}, vector.getType()), alloc));
321
322 Value casted =
323 castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
324 scf::ValueVector viewAndIndices{casted};
325 viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
326 zero);
327 scf::YieldOp::create(b, loc, viewAndIndices);
328 });
329}
330
331/// Given an `xferOp` for which:
332/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
333/// 2. a memref of single vector `alloc` has been allocated.
334/// Produce IR resembling:
335/// ```
336/// %1:3 = scf.if (%inBounds) {
337/// memref.cast %A: memref<A...> to compatibleMemRefType
338/// scf.yield %view, ... : compatibleMemRefType, index, index
339/// } else {
340/// %3 = vector.type_cast %extra_alloc :
341/// memref<...> to memref<vector<...>>
342/// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
343/// scf.yield %4, ... : compatibleMemRefType, index, index
344/// }
345/// ```
346static ValueRange
347getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
348 TypeRange returnTypes, Value inBoundsCond,
349 MemRefType compatibleMemRefType, Value alloc) {
350 Location loc = xferOp.getLoc();
351 Value zero = arith::ConstantIndexOp::create(b, loc, 0);
352 Value memref = xferOp.getBase();
353 return scf::IfOp::create(
354 b, loc, inBoundsCond,
355 [&](OpBuilder &b, Location loc) {
356 Value res =
357 castToCompatibleMemRefType(b, memref, compatibleMemRefType);
358 scf::ValueVector viewAndIndices{res};
359 llvm::append_range(viewAndIndices, xferOp.getIndices());
360 scf::YieldOp::create(b, loc, viewAndIndices);
361 },
362 [&](OpBuilder &b, Location loc) {
363 Value casted =
364 castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
365 scf::ValueVector viewAndIndices{casted};
366 viewAndIndices.insert(viewAndIndices.end(),
367 xferOp.getTransferRank(), zero);
368 scf::YieldOp::create(b, loc, viewAndIndices);
369 })
370 ->getResults();
371}
372
373/// Given an `xferOp` for which:
374/// 1. `inBoundsCond` has been computed.
375/// 2. a memref of single vector `alloc` has been allocated.
376/// 3. it originally wrote to %view
377/// Produce IR resembling:
378/// ```
379/// %notInBounds = arith.xori %inBounds, %true
380/// scf.if (%notInBounds) {
381/// %3 = subview %alloc [...][...][...]
382/// %4 = subview %view [0, 0][...][...]
383/// linalg.copy(%3, %4)
384/// }
385/// ```
386static void createFullPartialLinalgCopy(RewriterBase &b,
387 vector::TransferWriteOp xferOp,
388 Value inBoundsCond, Value alloc) {
389 Location loc = xferOp.getLoc();
390 auto notInBounds = arith::XOrIOp::create(
391 b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1));
392 scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
393 IRRewriter rewriter(b);
394 std::pair<Value, Value> copyArgs = createSubViewIntersection(
395 rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
396 alloc);
397 memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
398 scf::YieldOp::create(b, loc, ValueRange{});
399 });
400}
401
402/// Given an `xferOp` for which:
403/// 1. `inBoundsCond` has been computed.
404/// 2. a memref of single vector `alloc` has been allocated.
405/// 3. it originally wrote to %view
406/// Produce IR resembling:
407/// ```
408/// %notInBounds = arith.xori %inBounds, %true
409/// scf.if (%notInBounds) {
410/// %2 = load %alloc : memref<vector<...>>
411/// vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
412/// }
413/// ```
414static void createFullPartialVectorTransferWrite(RewriterBase &b,
415 vector::TransferWriteOp xferOp,
416 Value inBoundsCond,
417 Value alloc) {
418 Location loc = xferOp.getLoc();
419 auto notInBounds = arith::XOrIOp::create(
420 b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1));
421 scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
422 IRMapping mapping;
423 Value load = memref::LoadOp::create(
424 b, loc,
425 vector::TypeCastOp::create(
426 b, loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
427 ValueRange());
428 mapping.map(xferOp.getVector(), load);
429 b.clone(*xferOp.getOperation(), mapping);
430 scf::YieldOp::create(b, loc, ValueRange{});
431 });
432}
433
434// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
435static Operation *getAutomaticAllocationScope(Operation *op) {
436 // Find the closest surrounding allocation scope that is not a known looping
437 // construct (putting alloca's in loops doesn't always lower to deallocation
438 // until the end of the loop).
439 Operation *scope = nullptr;
440 for (Operation *parent = op->getParentOp(); parent != nullptr;
441 parent = parent->getParentOp()) {
442 if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
443 scope = parent;
444 if (!isa<scf::ForOp, affine::AffineForOp>(parent))
445 break;
446 }
447 assert(scope && "Expected op to be inside automatic allocation scope");
448 return scope;
449}
450
451/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
452/// masking) fastpath and a slowpath.
453///
454/// For vector.transfer_read:
455/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
456/// newly created conditional upon function return.
457/// To accomodate for the fact that the original vector.transfer indexing may be
458/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
459/// scf.if op returns a view and values of type index.
460///
461/// Example (a 2-D vector.transfer_read):
462/// ```
463/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
464/// ```
465/// is transformed into:
466/// ```
467/// %1:3 = scf.if (%inBounds) {
468/// // fastpath, direct cast
469/// memref.cast %A: memref<A...> to compatibleMemRefType
470/// scf.yield %view : compatibleMemRefType, index, index
471/// } else {
472/// // slowpath, not in-bounds vector.transfer or linalg.copy.
473/// memref.cast %alloc: memref<B...> to compatibleMemRefType
474/// scf.yield %4 : compatibleMemRefType, index, index
475// }
476/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
477/// ```
478/// where `alloc` is a top of the function alloca'ed buffer of one vector.
479///
480/// For vector.transfer_write:
481/// There are 2 conditional blocks. First a block to decide which memref and
482/// indices to use for an unmasked, inbounds write. Then a conditional block to
483/// further copy a partial buffer into the final result in the slow path case.
484///
485/// Example (a 2-D vector.transfer_write):
486/// ```
487/// vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
488/// ```
489/// is transformed into:
490/// ```
491/// %1:3 = scf.if (%inBounds) {
492/// memref.cast %A: memref<A...> to compatibleMemRefType
493/// scf.yield %view : compatibleMemRefType, index, index
494/// } else {
495/// memref.cast %alloc: memref<B...> to compatibleMemRefType
496/// scf.yield %4 : compatibleMemRefType, index, index
497/// }
498/// %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
499/// true]}
500/// scf.if (%notInBounds) {
501/// // slowpath: not in-bounds vector.transfer or linalg.copy.
502/// }
503/// ```
504/// where `alloc` is a top of the function alloca'ed buffer of one vector.
505///
506/// Preconditions:
507/// 1. `xferOp.getPermutationMap()` must be a minor identity map
508/// 2. the rank of the `xferOp.getBase()` and the rank of the
509/// `xferOp.getVector()` must be equal. This will be relaxed in the future
510/// but requires rank-reducing subviews.
511LogicalResult mlir::vector::splitFullAndPartialTransfer(
512 RewriterBase &b, VectorTransferOpInterface xferOp,
513 VectorTransformsOptions options, scf::IfOp *ifOp) {
514 if (options.vectorTransferSplit == VectorTransferSplit::None)
515 return failure();
516
517 SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
518 auto inBoundsAttr = b.getBoolArrayAttr(bools);
519 if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
520 b.modifyOpInPlace(xferOp, [&]() {
521 xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
522 });
523 return success();
524 }
525
526 // Assert preconditions. Additionally, keep the variables in an inner scope to
527 // ensure they aren't used in the wrong scopes further down.
528 {
529 assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
530 "Expected splitFullAndPartialTransferPrecondition to hold");
531
532 auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
533 auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
534
535 if (!(xferReadOp || xferWriteOp))
536 return failure();
537 if (xferWriteOp && xferWriteOp.getMask())
538 return failure();
539 if (xferReadOp && xferReadOp.getMask())
540 return failure();
541 }
542
544 b.setInsertionPoint(xferOp);
545 Value inBoundsCond = createInBoundsCond(
546 b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
547 if (!inBoundsCond)
548 return failure();
549
550 // Top of the function `alloc` for transient storage.
551 Value alloc;
552 {
554 Operation *scope = getAutomaticAllocationScope(xferOp);
555 assert(scope->getNumRegions() == 1 &&
556 "AutomaticAllocationScope with >1 regions");
557 b.setInsertionPointToStart(&scope->getRegion(0).front());
558 auto shape = xferOp.getVectorType().getShape();
559 Type elementType = xferOp.getVectorType().getElementType();
560 alloc = memref::AllocaOp::create(b, scope->getLoc(),
561 MemRefType::get(shape, elementType),
562 ValueRange{}, b.getI64IntegerAttr(32));
563 }
564
565 MemRefType compatibleMemRefType =
566 getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
567 cast<MemRefType>(alloc.getType()));
568 if (!compatibleMemRefType)
569 return failure();
570
571 SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
572 b.getIndexType());
573 returnTypes[0] = compatibleMemRefType;
574
575 if (auto xferReadOp =
576 dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
577 // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
578 scf::IfOp fullPartialIfOp =
579 options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
580 ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
581 inBoundsCond,
582 compatibleMemRefType, alloc)
583 : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
584 inBoundsCond, compatibleMemRefType,
585 alloc);
586 if (ifOp)
587 *ifOp = fullPartialIfOp;
588
589 // Set existing read op to in-bounds, it always reads from a full buffer.
590 for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
591 xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
592
593 b.modifyOpInPlace(xferOp, [&]() {
594 xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
595 });
596
597 return success();
598 }
599
600 auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
601
602 // Decide which location to write the entire vector to.
603 auto memrefAndIndices = getLocationToWriteFullVec(
604 b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
605
606 // Do an in bounds write to either the output or the extra allocated buffer.
607 // The operation is cloned to prevent deleting information needed for the
608 // later IR creation.
609 IRMapping mapping;
610 mapping.map(xferWriteOp.getBase(), memrefAndIndices.front());
611 mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
612 auto *clone = b.clone(*xferWriteOp, mapping);
613 clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
614
615 // Create a potential copy from the allocated buffer to the final output in
616 // the slow path case.
617 if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
618 createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
619 else
620 createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
621
622 b.eraseOp(xferOp);
623
624 return success();
625}
626
627namespace {
628/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
629/// may take an extra filter to perform selection at a finer granularity.
630struct VectorTransferFullPartialRewriter : public RewritePattern {
631 using FilterConstraintType =
632 std::function<LogicalResult(VectorTransferOpInterface op)>;
633
634 explicit VectorTransferFullPartialRewriter(
635 MLIRContext *context,
636 VectorTransformsOptions options = VectorTransformsOptions(),
637 FilterConstraintType filter =
638 [](VectorTransferOpInterface op) { return success(); },
639 PatternBenefit benefit = 1)
640 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
641 filter(std::move(filter)) {}
642
643 /// Performs the rewrite.
644 LogicalResult matchAndRewrite(Operation *op,
645 PatternRewriter &rewriter) const override;
646
647private:
648 VectorTransformsOptions options;
649 FilterConstraintType filter;
650};
651
652} // namespace
653
654LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
655 Operation *op, PatternRewriter &rewriter) const {
656 auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
657 if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
658 failed(filter(xferOp)))
659 return failure();
660 return splitFullAndPartialTransfer(rewriter, xferOp, options);
661}
662
665 patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
666 options);
667}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
auto load
static llvm::ManagedStatic< PassManagerOptions > options
static Value createInBoundsCond(RewriterBase &b, VectorTransferOpInterface xferOp)
Build the condition to ensure that a particular VectorTransferOpInterface is in-bounds.
Base type for affine expression.
Definition AffineExpr.h:68
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
A trait of region holding operations that define a new scope for automatic allocations,...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:715
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:703
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:255
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Block & front()
Definition Region.h:65
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition MemRefOps.cpp:70
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
Structure to control the behavior of vector transform patterns.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.