VectorTransferOpTransforms.cpp
//===- VectorTransferOpTransforms.cpp - transfer op transforms -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"

#define DEBUG_TYPE "vector-transfer-opt"

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  LDBG() << " Finding ancestor of " << *op << " in region";
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  if (op) {
    LDBG() << " -> Ancestor: " << *op;
  } else {
    LDBG() << " -> Ancestor: nullptr";
  }
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    LDBG() << "Removing " << opToErase.size() << " dead operations";
    for (Operation *op : opToErase) {
      LDBG() << " -> Erasing: " << *op;
      rewriter.eraseOp(op);
    }
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace
/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  LDBG() << " Checking reachability from " << *start << " to " << *dest;
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest)) {
    LDBG() << " -> Start dominates dest, reachable";
    return true;
  }
  bool blockReachable = start->getBlock()->isReachable(dest->getBlock());
  LDBG() << " -> Block reachable: " << blockReachable;
  return blockReachable;
}

/// For a transfer_write to be fully overwritten, another transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the
/// potentially dead transfer_write are dominated by the overwriting
/// transfer_write.
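///
/// As an illustrative sketch (%m, %v0, %v1 and the constants are placeholder
/// names, not taken from an actual test):
///
///   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///
/// Here the first transfer_write is dead: the second one post-dominates it,
/// overwrites the exact same region, and no read of %m happens in between.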
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LDBG() << "=== Starting deadStoreOp analysis for: " << *write.getOperation();
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
  LDBG() << "Source memref (after skipping view-like ops): " << source;
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  LDBG() << "Found " << users.size() << " users of source memref";
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    LDBG() << "Processing user: " << *user;
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second) {
      LDBG() << " -> Already processed, skipping";
      continue;
    }
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
      LDBG() << " -> View-like operation, following to destination";
      Value viewDest = viewLike.getViewDest();
      users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user)) {
      LDBG() << " -> Memory effect free, skipping";
      continue;
    }
    if (user == write.getOperation()) {
      LDBG() << " -> Same as write operation, skipping";
      continue;
    }
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      LDBG() << " -> Found transfer_write candidate: " << *nextWrite;
      // Check whether this candidate can overwrite the store.
      bool sameView = memref::isSameViewOrTrivialAlias(
          cast<MemrefValue>(nextWrite.getBase()),
          cast<MemrefValue>(write.getBase()));
      bool sameValue = checkSameValueWAW(nextWrite, write);
      bool postDominates = postDominators.postDominates(nextWrite, write);
      LDBG() << " -> Same view: " << sameView
             << ", Same value: " << sameValue
             << ", Post-dominates: " << postDominates;

      if (sameView && sameValue && postDominates) {
        LDBG() << " -> Valid overwrite candidate found";
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite)) {
          LDBG() << " -> New first overwrite candidate: " << *nextWrite;
          firstOverwriteCandidate = nextWrite;
        } else {
          LDBG() << " -> Keeping existing first overwrite candidate";
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        }
        continue;
      }
      LDBG() << " -> Not a valid overwrite candidate";
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      LDBG() << " -> Found vector transfer operation: " << *transferOp;
      // Don't need to consider disjoint accesses.
      bool isDisjoint = vector::isDisjointTransferSet(
          cast<VectorTransferOpInterface>(write.getOperation()),
          cast<VectorTransferOpInterface>(transferOp.getOperation()),
          /*testDynamicValueUsingBounds=*/true);
      LDBG() << " -> Is disjoint: " << isDisjoint;
      if (isDisjoint) {
        LDBG() << " -> Skipping disjoint access";
        continue;
      }
    }
    LDBG() << " -> Adding to blocking accesses: " << *user;
    blockingAccesses.push_back(user);
  }
  LDBG() << "Finished processing users. Found " << blockingAccesses.size()
         << " blocking accesses";

  if (firstOverwriteCandidate == nullptr) {
    LDBG() << "No overwrite candidate found, store is not dead";
    return;
  }

  LDBG() << "First overwrite candidate: " << *firstOverwriteCandidate;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");
  LDBG() << "Write ancestor in top region: " << *writeAncestor;

  LDBG() << "Checking " << blockingAccesses.size()
         << " blocking accesses for reachability";
  for (Operation *access : blockingAccesses) {
    LDBG() << "Checking blocking access: " << *access;
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr) {
      LDBG() << " -> No ancestor in top region, skipping";
      continue;
    }

    bool isReachableFromWrite = isReachable(writeAncestor, accessAncestor);
    LDBG() << " -> Is reachable from write: " << isReachableFromWrite;
    if (!isReachableFromWrite) {
      LDBG() << " -> Not reachable, skipping";
      continue;
    }

    bool overwriteDominatesAccess =
        dominators.dominates(firstOverwriteCandidate, accessAncestor);
    LDBG() << " -> Overwrite dominates access: " << overwriteDominatesAccess;
    if (!overwriteDominatesAccess) {
      LDBG() << "Store may not be dead due to op: " << *accessAncestor;
      return;
    }
    LDBG() << " -> Access is dominated by overwrite, continuing";
  }
  LDBG() << "Found dead store: " << *write.getOperation()
         << " overwritten by: " << *firstOverwriteCandidate;
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the
/// others since they all dominate the same transfer_read. We only consider
/// the transfer_write dominated by all the other candidates, as this will be
/// the last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are
/// post-dominated by the transfer_write.
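///
/// As an illustrative sketch (placeholder names, not from an actual test):
///
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<4x4xf32>, vector<4xf32>
///
/// All uses of %r can be replaced by %v, making the read redundant.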
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  LDBG() << "=== Starting storeToLoadForwarding analysis for: "
         << *read.getOperation();
  if (read.hasOutOfBoundsDim()) {
    LDBG() << "Read has out-of-bounds dimensions, skipping";
    return;
  }
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
  LDBG() << "Source memref (after skipping view-like ops): " << source;
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  LDBG() << "Found " << users.size() << " users of source memref";
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    LDBG() << "Processing user: " << *user;
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second) {
      LDBG() << " -> Already processed, skipping";
      continue;
    }
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
      LDBG() << " -> View-like operation, following to destination";
      Value viewDest = viewLike.getViewDest();
      users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) {
      LDBG() << " -> Memory effect free or transfer_read, skipping";
      continue;
    }
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      LDBG() << " -> Found transfer_write candidate: " << *write;
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      bool isDisjoint = vector::isDisjointTransferSet(
          cast<VectorTransferOpInterface>(write.getOperation()),
          cast<VectorTransferOpInterface>(read.getOperation()),
          /*testDynamicValueUsingBounds=*/true);
      LDBG() << " -> Is disjoint: " << isDisjoint;
      if (isDisjoint) {
        LDBG() << " -> Skipping disjoint write";
        continue;
      }

      bool sameView =
          memref::isSameViewOrTrivialAlias(cast<MemrefValue>(read.getBase()),
                                           cast<MemrefValue>(write.getBase()));
      bool dominates = dominators.dominates(write, read);
      bool sameValue = checkSameValueRAW(write, read);
      LDBG() << " -> Same view: " << sameView << ", Dominates: " << dominates
             << ", Same value: " << sameValue;

      if (sameView && dominates && sameValue) {
        LDBG() << " -> Valid forwarding candidate found";
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write)) {
          LDBG() << " -> New last write candidate: " << *write;
          lastwrite = write;
        } else {
          LDBG() << " -> Keeping existing last write candidate";
          assert(dominators.dominates(write, lastwrite));
        }
        continue;
      }
      LDBG() << " -> Not a valid forwarding candidate";
    }
    LDBG() << " -> Adding to blocking writes: " << *user;
    blockingWrites.push_back(user);
  }
  LDBG() << "Finished processing users. Found " << blockingWrites.size()
         << " blocking writes";

  if (lastwrite == nullptr) {
    LDBG() << "No last write candidate found, cannot forward";
    return;
  }

  LDBG() << "Last write candidate: " << *lastwrite;
  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");
  LDBG() << "Read ancestor in top region: " << *readAncestor;

  LDBG() << "Checking " << blockingWrites.size()
         << " blocking writes for post-dominance";
  for (Operation *write : blockingWrites) {
    LDBG() << "Checking blocking write: " << *write;
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    if (writeAncestor) {
      LDBG() << " -> Write ancestor: " << *writeAncestor;
    } else {
      LDBG() << " -> Write ancestor: nullptr";
    }

    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr) {
      LDBG() << " -> No ancestor in top region, skipping";
      continue;
    }

    bool isReachableToRead = isReachable(writeAncestor, readAncestor);
    LDBG() << " -> Is reachable to read: " << isReachableToRead;
    if (!isReachableToRead) {
      LDBG() << " -> Not reachable, skipping";
      continue;
    }

    bool lastWritePostDominates =
        postDominators.postDominates(lastwrite, write);
    LDBG() << " -> Last write post-dominates blocking write: "
           << lastWritePostDominates;
    if (!lastWritePostDominates) {
      LDBG() << "Failed to do write-to-read forwarding due to op: " << *write;
      return;
    }
    LDBG() << " -> Blocking write is post-dominated, continuing";
  }

  LDBG() << "Forward value from " << *lastwrite.getOperation()
         << " to: " << *read.getOperation();
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Converts OpFoldResults to int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return rankReducedType.canonicalizeStridedLayout();
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input. Or just returns the input if it was already without unit dims.
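///
/// As an illustrative sketch (%m is a placeholder), for a
/// memref<1x16x1x8xf32> input this builds roughly:
///
///   %sv = memref.subview %m[0, 0, 0, 0] [1, 16, 1, 8] [1, 1, 1, 1]
///       : memref<1x16x1x8xf32> to memref<16x8xf32>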
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (resultType.canonicalizeStridedLayout() ==
      inputType.canonicalizeStridedLayout())
    return input;
  return memref::SubViewOp::create(rewriter, loc, resultType, input, offsets,
                                   sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable one dimensions from `oldType` and returns the result
/// type.
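///
/// For example (illustrative): vector<1x[1]x4xf32> -> vector<[1]x4xf32>.
/// Only fixed-size unit dims are dropped; scalable unit dims ([1]) are kept.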
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

static bool isUnitDimMask(Value maskDimSize) {
  return matchPattern(maskDimSize, m_One());
}
static bool isUnitDimMask(int64_t maskDimSize) { return maskDimSize == 1; }

/// Rewrites a mask op to drop non-scalable unit dimensions.
/// Supports vector.create_mask and vector.constant_mask.
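///
/// As an illustrative sketch (%c1 and %d are placeholder names):
///
///   %mask = vector.create_mask %c1, %d : vector<1x[4]xi1>
///
/// becomes:
///
///   %mask = vector.create_mask %d : vector<[4]xi1>
///
/// The operand of each dropped dim must be a constant 1, otherwise we bail.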
template <typename MaskOp>
static FailureOr<Value> maskDropNonScalableUnitDims(PatternRewriter &rewriter,
                                                    Location loc, MaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  using ElemType = std::decay_t<decltype(*op.getMaskDimSizes().begin())>;
  SmallVector<ElemType> reduced;
  for (auto [dim, dimIsScalable, elem] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getMaskDimSizes())) {
    if (dim == 1 && !dimIsScalable) {
      if (!isUnitDimMask(elem))
        return failure();
      continue;
    }
    reduced.push_back(elem);
  }
  return MaskOp::create(rewriter, loc, reducedType, reduced).getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
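///
/// As an illustrative sketch (placeholder names, not from an actual test):
///
///   %v = vector.transfer_read %m[%c0, %c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true, true]}
///       : memref<1x8x1x4xf32>, vector<1x8x1x4xf32>
///
/// becomes (roughly):
///
///   %sv = memref.subview %m[0, 0, 0, 0] [1, 8, 1, 4] [1, 1, 1, 1]
///       : memref<1x8x1x4xf32> to memref<8x4xf32>
///   %r = vector.transfer_read %sv[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<8x4xf32>, vector<8x4xf32>
///   %v = vector.shape_cast %r : vector<8x4xf32> to vector<1x8x1x4xf32>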
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    LDBG() << "=== TransferReadDropUnitDimsPattern: Analyzing "
           << *transferReadOp;
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType) {
      LDBG() << " -> Not a MemRefType, skipping";
      return failure();
    }
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
      return failure();
    }
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    LDBG() << " -> Source rank: " << sourceType.getRank()
           << ", Reduced rank: " << reducedRank;
    if (reducedRank == sourceType.getRank()) {
      LDBG() << " -> No unit dimensions to drop, skipping";
      return failure();
    }
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp) {
      LDBG() << " -> 0-d vector with masking not supported, skipping";
      return failure();
    }
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    LDBG() << " -> Vector type: " << vectorType
           << ", Reduced vector type: " << reducedVectorType;
    if (reducedRank != reducedVectorType.getRank()) {
      LDBG() << " -> Reduced ranks don't match, skipping";
      return failure();
    }
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        })) {
      LDBG() << " -> Non-zero indices found, skipping";
      return failure();
    }

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      LDBG() << " -> Processing mask operation";
      auto maskVectorType = cast<VectorType>(maskOp.getType());
      FailureOr<Value> rankReducedMaskOp = failure();
      if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
        rankReducedMaskOp =
            maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      else if (auto constantMaskOp =
                   maskOp.getDefiningOp<vector::ConstantMaskOp>())
        rankReducedMaskOp =
            maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
      else
        return rewriter.notifyMatchFailure(
            transferReadOp,
            "unsupported mask op, only 'vector.create_mask' and "
            "'vector.constant_mask' are currently supported");

      if (succeeded(rankReducedMaskOp)) {
        maskOp = *rankReducedMaskOp;
        LDBG() << " -> Successfully reduced mask dimensions";
      } else if (maskVectorType.getRank() != reducedVectorType.getRank()) {
        return rewriter.notifyMatchFailure(
            transferReadOp, "Mask reduction required, but failed");
      }
    }

    LDBG() << " -> Creating rank-reduced subview and new transfer_read";
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = vector::TransferReadOp::create(
        rewriter, loc, reducedVectorType, reducedShapeSource, zeros,
        identityMap, transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    LDBG() << " -> Created new transfer_read: " << *newTransferReadOp;

    if (maskingOp) {
      LDBG() << " -> Applying masking operation";
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);
    LDBG() << " -> Created shape cast: " << *shapeCast.getDefiningOp();
    LDBG() << " -> Pattern match successful, returning result";

    return shapeCast;
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims.
/// The vector shapes are also reduced accordingly.
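///
/// As an illustrative sketch (placeholder names, not from an actual test):
///
///   vector.transfer_write %v, %m[%c0, %c0, %c0, %c0]
///       {in_bounds = [true, true, true, true]}
///       : vector<1x8x1x4xf32>, memref<1x8x1x4xf32>
///
/// becomes (roughly):
///
///   %sv = memref.subview %m[0, 0, 0, 0] [1, 8, 1, 4] [1, 1, 1, 1]
///       : memref<1x8x1x4xf32> to memref<8x4xf32>
///   %sc = vector.shape_cast %v : vector<1x8x1x4xf32> to vector<8x4xf32>
///   vector.transfer_write %sc, %sv[%c0, %c0] {in_bounds = [true, true]}
///       : vector<8x4xf32>, memref<8x4xf32>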
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    LDBG() << "=== TransferWriteDropUnitDimsPattern: Analyzing "
           << *transferWriteOp;
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType) {
      LDBG() << " -> Not a MemRefType, skipping";
      return failure();
    }
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferWriteOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
      return failure();
    }
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    LDBG() << " -> Source rank: " << sourceType.getRank()
           << ", Reduced rank: " << reducedRank;
    if (reducedRank == sourceType.getRank()) {
      LDBG() << " -> No unit dimensions to drop, skipping";
      return failure();
    }
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp) {
      LDBG() << " -> 0-d vector with masking not supported, skipping";
      return failure();
    }
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    LDBG() << " -> Vector type: " << vectorType
           << ", Reduced vector type: " << reducedVectorType;
    if (reducedRank != reducedVectorType.getRank()) {
      LDBG() << " -> Reduced ranks don't match, skipping";
      return failure();
    }
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        })) {
      LDBG() << " -> Non-zero indices found, skipping";
      return failure();
    }

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      LDBG() << " -> Processing mask operation";
      auto maskVectorType = cast<VectorType>(maskOp.getType());
      FailureOr<Value> rankReducedMask = failure();
      if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
        rankReducedMask =
            maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      else if (auto constantMaskOp =
                   maskOp.getDefiningOp<vector::ConstantMaskOp>())
        rankReducedMask =
            maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
      else
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' and "
            "'vector.constant_mask' are currently supported");

      if (succeeded(rankReducedMask)) {
        maskOp = *rankReducedMask;
        LDBG() << " -> Successfully reduced mask dimensions";
      } else if (maskVectorType.getRank() != reducedVectorType.getRank()) {
        return rewriter.notifyMatchFailure(
            transferWriteOp, "Mask reduction required, but failed");
      }
    }
    LDBG() << " -> Creating rank-reduced subview and new transfer_write";
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = vector::TransferWriteOp::create(
        rewriter, loc, Type(), shapeCastSrc, reducedShapeSource, zeros,
        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
    LDBG() << " -> Created new transfer_write: " << *newXferWrite;

    if (maskingOp) {
      LDBG() << " -> Applying masking operation";
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics()) {
      LDBG() << " -> Pattern match successful (tensor semantics), returning "
                "result";
      return newXferWrite->getResults()[0];
    }

    // With memref semantics, there's no return value. Use an empty value to
    // signal success.
    LDBG() << " -> Pattern match successful (memref semantics)";
    return Value();
  }
};

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
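///
/// As an illustrative sketch (%m is a placeholder), with
/// `firstDimToCollapse` = 1:
///
///   memref.collapse_shape %m [[0], [1, 2]]
///       : memref<2x3x4xf32> into memref<2x12xf32>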
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return memref::CollapseShapeOp::create(rewriter, loc, input, reassociation);
}

/// Returns the new indices that collapse the inner dimensions starting from
/// the `firstDimToCollapse` dimension.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  // If all the collapsed indices are zero then no extra logic is needed.
  // Otherwise, a new offset/index has to be computed.
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Compute the remaining trailing index/offset required for reading from
  // the collapsed memref:
  //
  //   offset = 0
  //   for (i = firstDimToCollapse; i < outputRank; ++i)
  //     offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
  //
  // For this example:
  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
  //     memref<1x43x2xi32>, vector<1x2xi32>
  // which would be collapsed to:
  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
  //     memref<1x86xi32>, vector<2xi32>
  // one would get the following offset:
  //   %offset = %arg0 * 43
  OpFoldResult collapsedOffset =
      arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();

  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));

  // Compute the collapsed offset.
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  } else {
    indicesAfterCollapsing.push_back(arith::ConstantIndexOp::create(
        rewriter, loc, *getConstantIntValue(collapsedOffset)));
  }

  return indicesAfterCollapsing;
}

namespace {
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
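///
/// As an illustrative sketch (placeholder names, not from an actual test):
///
///   %v = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xi32>, vector<4x8xi32>
///
/// becomes (roughly):
///
///   %cs = memref.collapse_shape %m [[0, 1]]
///       : memref<4x8xi32> into memref<32xi32>
///   %r = vector.transfer_read %cs[%c0], %pad {in_bounds = [true]}
///       : memref<32xi32>, vector<32xi32>
///   %v = vector.shape_cast %r : vector<32xi32> to vector<4x8xi32>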
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    LDBG() << "=== FlattenContiguousRowMajorTransferReadPattern: Analyzing "
           << *transferReadOp;
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below is only valid for memrefs; bail out on
    // tensors.
    if (!sourceType) {
      LDBG() << " -> Not a MemRefType, skipping";
      return failure();
    }
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1) {
      LDBG() << " -> Already 0D/1D, skipping";
      return failure();
    }
    if (!vectorType.getElementType().isSignlessIntOrFloat()) {
      LDBG() << " -> Not signless int or float, skipping";
      return failure();
    }
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    LDBG() << " -> Trailing vector dim bitwidth: " << trailingVectorDimBitwidth
           << ", target: " << targetVectorBitwidth;
    if (trailingVectorDimBitwidth >= targetVectorBitwidth) {
      LDBG() << " -> Trailing dim bitwidth >= target, skipping";
      return failure();
    }
    if (!vector::isContiguousSlice(sourceType, vectorType)) {
      LDBG() << " -> Not contiguous slice, skipping";
      return failure();
    }
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
      return failure();
    }
    if (transferReadOp.getMask()) {
      LDBG() << " -> Has mask, skipping";
      return failure();
    }

    // Determine vector dimensions to collapse.
    // Ignore a leading sequence of adjacent unit dimensions in the vector.
    ArrayRef<int64_t> collapsedVectorShape =
        vectorType.getShape().drop_while([](auto v) { return v == 1; });
    size_t collapsedVecRank = collapsedVectorShape.size();
    // Limit the collapse of multi-dimensional unit vectors (e.g. <1x1x1xf32>)
    // to a 1D single-element vector.
    if (collapsedVecRank == 0)
      collapsedVecRank = 1;

    // Determine the first memref dimension to collapse - just enough so we can
    // read a flattened vector.
    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
    LDBG() << " -> First dimension to collapse: " << firstDimToCollapse;

    // 1. Collapse the source memref
    LDBG() << " -> Collapsing source memref";
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);
    LDBG() << " -> Collapsed source type: " << collapsedSourceType;

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    LDBG() << " -> Creating flattened vector type: " << flatVectorType;
    vector::TransferReadOp flatRead = vector::TransferReadOp::create(
        rewriter, loc, flatVectorType, collapsedSource, collapsedIndices,
        transferReadOp.getPadding(), collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    LDBG() << " -> Created flat transfer_read: " << *flatRead;

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape
    LDBG() << " -> Replacing with shape cast";
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    LDBG() << " -> Pattern match successful";
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source (i.e. the write's destination) so that
/// the resulting vector.transfer_write has a 1D destination. Requires the
/// destination shape to be already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector write is smaller than the provided
/// bitwidth.
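///
/// As an illustrative sketch (placeholder names, not from an actual test):
///
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xi32>, memref<4x8xi32>
///
/// becomes (roughly):
///
///   %sc = vector.shape_cast %v : vector<4x8xi32> to vector<32xi32>
///   %cs = memref.collapse_shape %m [[0, 1]]
///       : memref<4x8xi32> into memref<32xi32>
///   vector.transfer_write %sc, %cs[%c0] {in_bounds = [true]}
///       : vector<32xi32>, memref<32xi32>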
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below is only valid for memrefs; bail out on
    // tensors.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    // Determine vector dimensions to collapse.
    // Ignore a leading sequence of adjacent unit dimensions in the vector.
    ArrayRef<int64_t> collapsedVectorShape =
        vectorType.getShape().drop_while([](auto v) { return v == 1; });
    size_t collapsedVecRank = collapsedVectorShape.size();
    // Limit the collapse of multi-dimensional unit vectors (e.g. <1x1x1xf32>)
    // to a 1D single-element vector.
    if (collapsedVecRank == 0)
      collapsedVecRank = 1;

    // Determine the first memref dimension to collapse - just enough so we can
    // write a flattened vector.
    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_write that will write
    // to the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_write that writes to the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        vector::ShapeCastOp::create(rewriter, loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite = vector::TransferWriteOp::create(
        rewriter, loc, flatVector, collapsedSource, collapsedIndices,
        collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Erase the old transfer_write; the new write to the collapsed shape
    // replaces it (a memref transfer_write has no results).
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be `vector.extract` ops. If
/// `allowMultipleUses` is set to true, rewrite transfer ops with any number of
/// users. Otherwise, rewrite only if the extract op is the single user of the
/// transfer op. Rewriting a single vector load with multiple scalar loads may
/// negatively affect performance.
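///
/// As an illustrative sketch (placeholder names, not from an actual test):
///
///   %v = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
///   %e = vector.extract %v[2] : f32 from vector<4xf32>
///
/// becomes (roughly):
///
///   %idx = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%i]
///   %e = memref.load %m[%idx] : memref<?xf32>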
class RewriteScalarExtractOfTransferRead
    : public OpRewritePattern<vector::ExtractOp> {
public:
  RewriteScalarExtractOfTransferRead(MLIRContext *context,
                                     PatternBenefit benefit,
                                     bool allowMultipleUses)
      : OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Match phase.
    auto xferOp = extractOp.getSource().getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp>(use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();

    // Rewrite phase: construct scalar load.
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;

      // Compute affine expression `newIndices[idx] + pos` where `pos` can be
      // either a constant or a value.
      OpFoldResult composedIdx;
      if (auto attr = dyn_cast<Attribute>(pos)) {
        int64_t offset = cast<IntegerAttr>(attr).getInt();
        composedIdx = affine::makeComposedFoldedAffineApply(
            rewriter, extractOp.getLoc(),
            rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      } else {
        Value dynamicOffset = cast<Value>(pos);
        AffineExpr sym0, sym1;
        bindSymbols(rewriter.getContext(), sym0, sym1);
        composedIdx = affine::makeComposedFoldedAffineApply(
            rewriter, extractOp.getLoc(), sym0 + sym1,
            {newIndices[idx], dynamicOffset});
      }

      // Update the corresponding index with the folded result.
      if (auto value = dyn_cast<Value>(composedIdx)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = arith::ConstantIndexOp::create(
            rewriter, extractOp.getLoc(), *getConstantIntValue(composedIdx));
      }
    }
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }

    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
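///
/// As an illustrative sketch (placeholder names, not from an actual test):
///
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<4x4xf32>
///
/// becomes (roughly):
///
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %m[%c0, %c0] : memref<4x4xf32>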
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Extract the scalar to be stored.
    Value scalar = vector::ExtractOp::create(rewriter, xferOp.getLoc(),
                                             xferOp.getVector());
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  LDBG() << "=== Starting transferOpflowOpt on root operation: "
         << OpWithFlags(rootOp, OpPrintingFlags().skipRegions());
  TransferOptimization opt(rewriter, rootOp);

  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  LDBG() << "Phase 1: Store-to-load forwarding";
  int readCount = 0;
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType())) {
      LDBG() << "Processing transfer_read #" << ++readCount << ": " << *read;
      opt.storeToLoadForwarding(read);
    }
  });
  LDBG() << "Phase 1 complete. Removing dead operations from forwarding";
  opt.removeDeadOp();

  LDBG() << "Phase 2: Dead store elimination";
  int writeCount = 0;
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType())) {
      LDBG() << "Processing transfer_write #" << ++writeCount << ": " << *write;
      opt.deadStoreOp(write);
    }
  });
  LDBG() << "Phase 2 complete. Removing dead operations from dead store "
            "elimination";
  opt.removeDeadOp();
  LDBG() << "=== transferOpflowOpt complete";
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}