//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <optional>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getBase(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sle,
                              getValueOrCreateConstantIndexOp(b, loc, sum),
                              getValueOrCreateConstantIndexOp(b, loc, dimSz));
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = arith::AndIOp::create(b, loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}
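
// Note: a minimal sketch of the IR built above for a transfer that may be
// out of bounds along a single dimension (value names are illustrative).
// With a vector size of 4 along that dimension, the condition is roughly:
//
//   %end  = affine.apply affine_map<(d0) -> (d0 + 4)>(%index)
//   %dim  = memref.dim %base, %c0 : memref<?x8xf32>
//   %cond = arith.cmpi sle, %end, %dim : index
//
// Per-dimension conditions for the remaining maybe-out-of-bounds dimensions
// are chained with arith.andi; statically provable dimensions are skipped.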

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of `xferOp.getBase()` and the rank of `xferOp.getVector()` must
///     be equal. This will be relaxed in the future but requires rank-reducing
///     subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an scf.if op; this avoids
  // applying the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}
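
// Note: an illustrative example of the type join computed above. With
//   aT = memref<?x8xf32> and bT = memref<4x8xf32>
// the two types are already memref.cast-compatible, so `aT` is returned
// unchanged. With
//   aT = memref<4x8xf32, strided<[16, 1]>> and bT = memref<4x8xf32>
// the static strides disagree in the leading dimension, so the result is
//   memref<4x8xf32, strided<[?, 1]>>.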

/// Casts the given memref to a compatible memref type. If the source memref
/// has a different address space than the target type, a
/// `memref.memory_space_cast` is first inserted, followed by a `memref.cast`.
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res =
        memref::MemorySpaceCastOp::create(b, memref.getLoc(), sourceType, res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return memref::CastOp::create(b, memref.getLoc(), compatibleMemRefType, res);
}
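
// Note: an illustrative example of the cast chain built above. Casting
// `%m : memref<4x8xf32, 1>` to `memref<?x8xf32>` produces:
//
//   %0 = memref.memory_space_cast %m : memref<4x8xf32, 1> to memref<4x8xf32>
//   %1 = memref.cast %0 : memref<4x8xf32> to memref<?x8xf32>
//
// When the memory spaces already match, only the memref.cast is emitted, and
// it too is skipped if the types are already equal.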

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.getBase()` @ `xferOp.getIndices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef =
        memref::DimOp::create(b, xferOp.getLoc(), xferOp.getBase(), indicesIdx);
    Value dimAlloc = memref::DimOp::create(b, loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin =
        affine::AffineMinOp::create(b, loc, index.getType(), maps[0],
                                    ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::map_to_vector<4>(
      xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; });
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = memref::SubViewOp::create(
      b, loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
  auto copyDest = memref::SubViewOp::create(
      b, loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}
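
// Note: a rough sketch of the intersection computed above for a 2-D
// vector.transfer_read from `%A : memref<?x?xf32>` at `[%i, %j]` with
// `%alloc : memref<4x8xf32>` (names are illustrative):
//
//   %d0  = memref.dim %A, %c0          %a0 = memref.dim %alloc, %c0
//   %sz0 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%d0, %i, %a0)
//   ... likewise %sz1 along dimension 1 ...
//   %src  = memref.subview %A[%i, %j] [%sz0, %sz1] [1, 1]
//   %dest = memref.subview %alloc[0, 0] [%sz0, %sz1] [1, 1]
//
// i.e. each size is the number of elements that both the source view and the
// single-vector buffer can hold along that dimension.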

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
      b, loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        scf::YieldOp::create(b, loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        linalg::FillOp::create(b, loc, ValueRange{xferOp.getPadding()},
                               ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        scf::YieldOp::create(b, loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
      b, loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        scf::YieldOp::create(b, loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        memref::StoreOp::create(
            b, loc, vector,
            vector::TypeCastOp::create(
                b, loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        scf::YieldOp::create(b, loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
             b, loc, inBoundsCond,
             [&](OpBuilder &b, Location loc) {
               Value res =
                   castToCompatibleMemRefType(b, memref, compatibleMemRefType);
               scf::ValueVector viewAndIndices{res};
               llvm::append_range(viewAndIndices, xferOp.getIndices());
               scf::YieldOp::create(b, loc, viewAndIndices);
             },
             [&](OpBuilder &b, Location loc) {
               Value casted =
                   castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
               scf::ValueVector viewAndIndices{casted};
               viewAndIndices.insert(viewAndIndices.end(),
                                     xferOp.getTransferRank(), zero);
               scf::YieldOp::create(b, loc, viewAndIndices);
             })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = arith::XOrIOp::create(
      b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1));
  scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
    scf::YieldOp::create(b, loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = arith::XOrIOp::create(
      b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1));
  scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = memref::LoadOp::create(
        b, loc,
        vector::TypeCastOp::create(
            b, loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange());
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    scf::YieldOp::create(b, loc, ValueRange{});
  });
}

// TODO: Parallelism and thread-local considerations with a ParallelScope
// trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting allocas in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, affine::AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}
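
// Note: for example, with a transfer op nested as
//   func.func @f(...) { scf.for ... { vector.transfer_read ... } }
// the scf.for is an allocation scope but also a loop, so the walk continues
// and the func.func op is returned; the single-vector alloca is thus hoisted
// out of the loop.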

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block
/// to further copy a partial buffer into the final result in the slow path
/// case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2]
///        {in_bounds = [true ... true]}
///    scf.if (%notInBounds) {
///      // slow path: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of `xferOp.getBase()` and the rank of `xferOp.getVector()` must
///     be equal. This will be relaxed in the future but requires rank-reducing
///     subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = memref::AllocaOp::create(b, scope->getLoc(),
                                     MemRefType::get(shape, elementType),
                                     ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getBase(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
        filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}
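
// Note: a minimal usage sketch (illustrative; assumes a surrounding pass with
// a `funcOp` and the greedy pattern rewrite driver):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::VectorTransformsOptions options;
//   options.vectorTransferSplit = vector::VectorTransferSplit::VectorTransfer;
//   vector::populateVectorTransferFullPartialPatterns(patterns, options);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));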