MLIR 22.0.0git
VectorTransferSplitRewritePatterns.cpp
//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <optional>

// (The remaining MLIR dialect and interface #includes are elided in this
// listing.)

#include "llvm/ADT/STLExtras.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getBase(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sle,
                              getValueOrCreateConstantIndexOp(b, loc, sum),
                              getValueOrCreateConstantIndexOp(b, loc, dimSz));
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = arith::AndIOp::create(b, loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}
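
// For illustration (a sketch, not part of the upstream file): for a read of
// vector<4x8xf32> from %A : memref<?x?xf32> at indices (%i, %j) where neither
// dimension is known to be in-bounds, the condition built above resembles:
// ```
//   %end0  = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
//   %dim0  = memref.dim %A, %c0 : memref<?x?xf32>
//   %cond0 = arith.cmpi sle, %end0, %dim0 : index
//   %end1  = affine.apply affine_map<(d0) -> (d0 + 8)>(%j)
//   %dim1  = memref.dim %A, %c1 : memref<?x?xf32>
//   %cond1 = arith.cmpi sle, %end1, %dim1 : index
//   %inBounds = arith.andi %cond0, %cond1 : i1
// ```
// Dimensions that are already marked in-bounds, or that fold to a provably
// in-bounds access, simply drop out of the conjunction.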

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of `xferOp.getBase()` and the rank of `xferOp.getVector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an scf.if; this avoids
  // applying the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}
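
// For illustration (a sketch, not part of the upstream file): a transfer such
// as
// ```
//   %v = vector.transfer_read %A[%i, %j], %pad {in_bounds = [true, false]}
//        : memref<?x?xf32>, vector<4x8xf32>
// ```
// is a candidate (minor identity permutation map, at least one potentially
// out-of-bounds dimension), whereas a 0-d transfer, a transfer with a
// non-minor-identity permutation map, or one nested directly under an scf.if
// is rejected.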

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}
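
// Worked example (a sketch, not part of the upstream file): joining
//   aT = memref<4x8xf32, strided<[8, 1]>>
//   bT = memref<4x16xf32>
// keeps the leading dimension (4) and the innermost stride (1), which agree,
// and relaxes the mismatched size and stride to dynamic, yielding
//   memref<4x?xf32, strided<[?, 1]>>
// to which values of both `aT` and `bT` can be memref.cast.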

/// Casts the given memref to a compatible memref type. If the source memref
/// has a different address space than the target type, a
/// `memref.memory_space_cast` is first inserted, followed by a `memref.cast`.
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res =
        memref::MemorySpaceCastOp::create(b, memref.getLoc(), sourceType, res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return memref::CastOp::create(b, memref.getLoc(), compatibleMemRefType, res);
}
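
// For illustration (a sketch, not part of the upstream file): casting a source
// buffer that lives in address space 3 to a compatible type in the default
// address space emits two ops:
// ```
//   %0 = memref.memory_space_cast %src : memref<4x8xf32, 3> to memref<4x8xf32>
//   %1 = memref.cast %0 : memref<4x8xf32> to memref<4x?xf32, strided<[?, 1]>>
// ```
// When the memory spaces already match, only the trailing memref.cast (or no
// op at all, if the types are equal) is created.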

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.getBase()` @ `xferOp.getIndices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef =
        memref::DimOp::create(b, xferOp.getLoc(), xferOp.getBase(), indicesIdx);
    Value dimAlloc = memref::DimOp::create(b, loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin =
        affine::AffineMinOp::create(b, loc, index.getType(), maps[0],
                                    ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = memref::SubViewOp::create(
      b, loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
  auto copyDest = memref::SubViewOp::create(
      b, loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}
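
// For illustration (a sketch, not part of the upstream file): for a read of
// vector<4x8xf32> from %A : memref<?x?xf32> at (%i, %j) with %alloc holding a
// single vector, the generated intersection resembles:
// ```
//   %d0 = memref.dim %A, %c0 : memref<?x?xf32>
//   %d1 = memref.dim %A, %c1 : memref<?x?xf32>
//   %s0 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%d0, %i, %c4)
//   %s1 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%d1, %j, %c8)
//   %src = subview %A[%i, %j] [%s0, %s1] [1, 1]
//   %dst = subview %alloc[0, 0] [%s0, %s1] [1, 1]
// ```
// so that copying %src into %dst touches only elements that exist in both the
// original memref and the temporary buffer.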

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
      b, loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        scf::YieldOp::create(b, loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        linalg::FillOp::create(b, loc, ValueRange{xferOp.getPadding()},
                               ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        scf::YieldOp::create(b, loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
      b, loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        scf::YieldOp::create(b, loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        memref::StoreOp::create(
            b, loc, vector,
            vector::TypeCastOp::create(
                b, loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        scf::YieldOp::create(b, loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
             b, loc, inBoundsCond,
             [&](OpBuilder &b, Location loc) {
               Value res =
                   castToCompatibleMemRefType(b, memref, compatibleMemRefType);
               scf::ValueVector viewAndIndices{res};
               llvm::append_range(viewAndIndices, xferOp.getIndices());
               scf::YieldOp::create(b, loc, viewAndIndices);
             },
             [&](OpBuilder &b, Location loc) {
               Value casted =
                   castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
               scf::ValueVector viewAndIndices{casted};
               viewAndIndices.insert(viewAndIndices.end(),
                                     xferOp.getTransferRank(), zero);
               scf::YieldOp::create(b, loc, viewAndIndices);
             })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = arith::XOrIOp::create(
      b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1));
  scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
    scf::YieldOp::create(b, loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = arith::XOrIOp::create(
      b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1));
  scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = memref::LoadOp::create(
        b, loc,
        vector::TypeCastOp::create(
            b, loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange());
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    scf::YieldOp::create(b, loc, ValueRange{});
  });
}

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting alloca's in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, affine::AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}
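
// For illustration (a sketch, not part of the upstream file): given
// ```
//   func.func @f(%A: memref<?x?xf32>, ...) {
//     scf.for %i = %c0 to %n step %c4 {
//       %v = vector.transfer_read %A[%i, %c0], %pad
//            : memref<?x?xf32>, vector<4x8xf32>
//       ...
//     }
//   }
// ```
// the walk above skips over the scf.for and returns the func.func, so the
// transient buffer is alloca'ed once at function entry rather than once per
// loop iteration.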

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-function alloca'ed buffer of one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block
/// to further copy a partial buffer into the final result in the slow path
/// case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2]
///           {in_bounds = [true ... true]}
///    scf.if (%notInBounds) {
///      // slow path: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a top-of-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getBase()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = memref::AllocaOp::create(b, scope->getLoc(),
                                     MemRefType::get(shape, elementType),
                                     ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getBase(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}
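
// For illustration (a sketch, not part of the upstream file; the surrounding
// setup is assumed, not shown): the entry point can also be driven directly,
// e.g. from a test or transform pass, once the precondition has been checked:
// ```
//   IRRewriter rewriter(xferOp.getContext());
//   VectorTransformsOptions options;
//   options.vectorTransferSplit = VectorTransferSplit::VectorTransfer;
//   scf::IfOp ifOp;
//   if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options,
//                                             &ifOp))) {
//     // `ifOp` now points at the newly created conditional (read case).
//   }
// ```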

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
        filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}
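
// For illustration (a sketch, not part of the upstream file; the pass
// boilerplate is assumed): a typical client populates these patterns and runs
// them with the greedy driver:
// ```
//   RewritePatternSet patterns(funcOp.getContext());
//   VectorTransformsOptions options;
//   options.vectorTransferSplit = VectorTransferSplit::LinalgCopy;
//   populateVectorTransferFullPartialPatterns(patterns, options);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();
// ```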