VectorTransferOpTransforms.cpp
1 //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with optimizing transfer_read and
10 // transfer_write ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Utils/IndexingUtils.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/Operation.h"
26 #include "mlir/Interfaces/SideEffectInterfaces.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/DebugLog.h"
30 
31 #define DEBUG_TYPE "vector-transfer-opt"
32 
33 using namespace mlir;
34 
35 /// Return the ancestor op in the region or nullptr if the region is not
36 /// an ancestor of the op.
37 static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
38   LDBG() << " Finding ancestor of " << *op << " in region";
39  for (; op != nullptr && op->getParentRegion() != region;
40  op = op->getParentOp())
41  ;
42  if (op) {
43  LDBG() << " -> Ancestor: " << *op;
44  } else {
45  LDBG() << " -> Ancestor: nullptr";
46  }
47  return op;
48 }
49 
50 namespace {
51 
52 class TransferOptimization {
53 public:
54  TransferOptimization(RewriterBase &rewriter, Operation *op)
55  : rewriter(rewriter), dominators(op), postDominators(op) {}
56  void deadStoreOp(vector::TransferWriteOp);
57  void storeToLoadForwarding(vector::TransferReadOp);
58  void removeDeadOp() {
59  LDBG() << "Removing " << opToErase.size() << " dead operations";
60  for (Operation *op : opToErase) {
61  LDBG() << " -> Erasing: " << *op;
62  rewriter.eraseOp(op);
63  }
64  opToErase.clear();
65  }
66 
67 private:
68  RewriterBase &rewriter;
69  bool isReachable(Operation *start, Operation *dest);
70  DominanceInfo dominators;
71  PostDominanceInfo postDominators;
72  std::vector<Operation *> opToErase;
73 };
74 
75 } // namespace
76 /// Return true if there is a path from start operation to dest operation,
77 /// otherwise return false. The operations have to be in the same region.
78 bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
79  LDBG() << " Checking reachability from " << *start << " to " << *dest;
80  assert(start->getParentRegion() == dest->getParentRegion() &&
81          "This function only works for ops in the same region");
82   // Simple case where the start op dominates the destination.
83  if (dominators.dominates(start, dest)) {
84  LDBG() << " -> Start dominates dest, reachable";
85  return true;
86  }
87  bool blockReachable = start->getBlock()->isReachable(dest->getBlock());
88  LDBG() << " -> Block reachable: " << blockReachable;
89  return blockReachable;
90 }
91 
92 /// For a transfer_write to fully overwrite another, the overwriting
92 /// transfer_write must:
93 /// 1. Access the same memref with the same indices and vector type.
94 /// 2. Post-dominate the other transfer_write operation.
95 /// If several candidates are available, one must be post-dominated by all the
96 /// others since they are all post-dominating the same transfer_write. We only
97 /// consider the transfer_write post-dominated by all the other candidates as
98 /// this will be the first transfer_write executed after the potentially dead
99 /// transfer_write.
100 /// If we find such an overwriting transfer_write, we know that the original
101 /// transfer_write is dead if all reads reachable from the potentially dead
102 /// transfer_write are dominated by the overwriting transfer_write.
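103 /// As an illustrative sketch (hypothetical IR, not taken from this file or
103 /// its tests), the first write below is dead: the second write post-dominates
103 /// it, stores to the same memref at the same indices with the same vector
103 /// type, and no read of %m is reachable in between:
103 ///
103 ///   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true]}
103 ///       : vector<4xf32>, memref<8x8xf32>
103 ///   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true]}
103 ///       : vector<4xf32>, memref<8x8xf32>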
103 void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
104  LDBG() << "=== Starting deadStoreOp analysis for: " << *write.getOperation();
105  llvm::SmallVector<Operation *, 8> blockingAccesses;
106  Operation *firstOverwriteCandidate = nullptr;
107  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
108  LDBG() << "Source memref (after skipping view-like ops): " << source;
109  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
110  source.getUsers().end());
111  LDBG() << "Found " << users.size() << " users of source memref";
112  llvm::SmallDenseSet<Operation *, 32> processed;
113  while (!users.empty()) {
114  Operation *user = users.pop_back_val();
115  LDBG() << "Processing user: " << *user;
116  // If the user has already been processed skip.
117  if (!processed.insert(user).second) {
118  LDBG() << " -> Already processed, skipping";
119  continue;
120  }
121  if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
122  LDBG() << " -> View-like operation, following to destination";
123  Value viewDest = viewLike.getViewDest();
124  users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
125  continue;
126  }
127  if (isMemoryEffectFree(user)) {
128  LDBG() << " -> Memory effect free, skipping";
129  continue;
130  }
131  if (user == write.getOperation()) {
132  LDBG() << " -> Same as write operation, skipping";
133  continue;
134  }
135  if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
136  LDBG() << " -> Found transfer_write candidate: " << *nextWrite;
137  // Check candidate that can override the store.
138  bool sameView = memref::isSameViewOrTrivialAlias(
139  cast<MemrefValue>(nextWrite.getBase()),
140  cast<MemrefValue>(write.getBase()));
141  bool sameValue = checkSameValueWAW(nextWrite, write);
142  bool postDominates = postDominators.postDominates(nextWrite, write);
143  LDBG() << " -> Same view: " << sameView
144  << ", Same value: " << sameValue
145  << ", Post-dominates: " << postDominates;
146 
147  if (sameView && sameValue && postDominates) {
148  LDBG() << " -> Valid overwrite candidate found";
149  if (firstOverwriteCandidate == nullptr ||
150  postDominators.postDominates(firstOverwriteCandidate, nextWrite)) {
151  LDBG() << " -> New first overwrite candidate: " << *nextWrite;
152  firstOverwriteCandidate = nextWrite;
153  } else {
154  LDBG() << " -> Keeping existing first overwrite candidate";
155  assert(
156  postDominators.postDominates(nextWrite, firstOverwriteCandidate));
157  }
158  continue;
159  }
160  LDBG() << " -> Not a valid overwrite candidate";
161  }
162  if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
163  LDBG() << " -> Found vector transfer operation: " << *transferOp;
164  // Don't need to consider disjoint accesses.
165  bool isDisjoint = vector::isDisjointTransferSet(
166  cast<VectorTransferOpInterface>(write.getOperation()),
167  cast<VectorTransferOpInterface>(transferOp.getOperation()),
168  /*testDynamicValueUsingBounds=*/true);
169  LDBG() << " -> Is disjoint: " << isDisjoint;
170  if (isDisjoint) {
171  LDBG() << " -> Skipping disjoint access";
172  continue;
173  }
174  }
175  LDBG() << " -> Adding to blocking accesses: " << *user;
176  blockingAccesses.push_back(user);
177  }
178  LDBG() << "Finished processing users. Found " << blockingAccesses.size()
179  << " blocking accesses";
180 
181  if (firstOverwriteCandidate == nullptr) {
182  LDBG() << "No overwrite candidate found, store is not dead";
183  return;
184  }
185 
186  LDBG() << "First overwrite candidate: " << *firstOverwriteCandidate;
187  Region *topRegion = firstOverwriteCandidate->getParentRegion();
188  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
189  assert(writeAncestor &&
190  "write op should be recursively part of the top region");
191  LDBG() << "Write ancestor in top region: " << *writeAncestor;
192 
193  LDBG() << "Checking " << blockingAccesses.size()
194  << " blocking accesses for reachability";
195  for (Operation *access : blockingAccesses) {
196  LDBG() << "Checking blocking access: " << *access;
197  Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
198  // TODO: if the access and write have the same ancestor we could recurse in
199  // the region to know if the access is reachable with more precision.
200  if (accessAncestor == nullptr) {
201  LDBG() << " -> No ancestor in top region, skipping";
202  continue;
203  }
204 
205  bool isReachableFromWrite = isReachable(writeAncestor, accessAncestor);
206  LDBG() << " -> Is reachable from write: " << isReachableFromWrite;
207  if (!isReachableFromWrite) {
208  LDBG() << " -> Not reachable, skipping";
209  continue;
210  }
211 
212  bool overwriteDominatesAccess =
213  dominators.dominates(firstOverwriteCandidate, accessAncestor);
214  LDBG() << " -> Overwrite dominates access: " << overwriteDominatesAccess;
215  if (!overwriteDominatesAccess) {
216  LDBG() << "Store may not be dead due to op: " << *accessAncestor;
217  return;
218  }
219  LDBG() << " -> Access is dominated by overwrite, continuing";
220  }
221  LDBG() << "Found dead store: " << *write.getOperation()
222  << " overwritten by: " << *firstOverwriteCandidate;
223  opToErase.push_back(write.getOperation());
224 }
225 
226 /// A transfer_write candidate for store-to-load forwarding must:
227 /// 1. Access the same memref with the same indices and vector type as the
228 /// transfer_read.
229 /// 2. Dominate the transfer_read operation.
230 /// If several candidates are available, one must be dominated by all the others
231 /// since they are all dominating the same transfer_read. We only consider the
232 /// transfer_write dominated by all the other candidates as this will be the
233 /// last transfer_write executed before the transfer_read.
234 /// If we find such a candidate, we can forward the stored value if all the
235 /// other potentially aliasing ops that may reach the transfer_read are
236 /// post-dominated by the transfer_write.
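237 /// As an illustrative sketch (hypothetical IR, not taken from this file or
237 /// its tests), the read below can simply reuse %v: the dominating write
237 /// stores %v to the same memref at the same indices with a matching vector
237 /// type, and no other write to %m can reach the read:
237 ///
237 ///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true]}
237 ///       : vector<4xf32>, memref<8x8xf32>
237 ///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true]}
237 ///       : memref<8x8xf32>, vector<4xf32>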
237 void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
238  LDBG() << "=== Starting storeToLoadForwarding analysis for: "
239  << *read.getOperation();
240  if (read.hasOutOfBoundsDim()) {
241  LDBG() << "Read has out-of-bounds dimensions, skipping";
242  return;
243  }
244  SmallVector<Operation *, 8> blockingWrites;
245  vector::TransferWriteOp lastwrite = nullptr;
246  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
247  LDBG() << "Source memref (after skipping view-like ops): " << source;
248  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
249  source.getUsers().end());
250  LDBG() << "Found " << users.size() << " users of source memref";
251  llvm::SmallDenseSet<Operation *, 32> processed;
252  while (!users.empty()) {
253  Operation *user = users.pop_back_val();
254  LDBG() << "Processing user: " << *user;
255  // If the user has already been processed skip.
256  if (!processed.insert(user).second) {
257  LDBG() << " -> Already processed, skipping";
258  continue;
259  }
260  if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
261  LDBG() << " -> View-like operation, following to destination";
262  Value viewDest = viewLike.getViewDest();
263  users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
264  continue;
265  }
266  if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) {
267  LDBG() << " -> Memory effect free or transfer_read, skipping";
268  continue;
269  }
270  if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
271  LDBG() << " -> Found transfer_write candidate: " << *write;
272  // If there is a write, but we can prove that it is disjoint we can ignore
273  // the write.
274  bool isDisjoint = vector::isDisjointTransferSet(
275  cast<VectorTransferOpInterface>(write.getOperation()),
276  cast<VectorTransferOpInterface>(read.getOperation()),
277  /*testDynamicValueUsingBounds=*/true);
278  LDBG() << " -> Is disjoint: " << isDisjoint;
279  if (isDisjoint) {
280  LDBG() << " -> Skipping disjoint write";
281  continue;
282  }
283 
284  bool sameView =
285  memref::isSameViewOrTrivialAlias(cast<MemrefValue>(read.getBase()),
286  cast<MemrefValue>(write.getBase()));
287  bool dominates = dominators.dominates(write, read);
288  bool sameValue = checkSameValueRAW(write, read);
289  LDBG() << " -> Same view: " << sameView << ", Dominates: " << dominates
290  << ", Same value: " << sameValue;
291 
292  if (sameView && dominates && sameValue) {
293  LDBG() << " -> Valid forwarding candidate found";
294  if (lastwrite == nullptr || dominators.dominates(lastwrite, write)) {
295  LDBG() << " -> New last write candidate: " << *write;
296  lastwrite = write;
297  } else {
298  LDBG() << " -> Keeping existing last write candidate";
299  assert(dominators.dominates(write, lastwrite));
300  }
301  continue;
302  }
303  LDBG() << " -> Not a valid forwarding candidate";
304  }
305  LDBG() << " -> Adding to blocking writes: " << *user;
306  blockingWrites.push_back(user);
307  }
308  LDBG() << "Finished processing users. Found " << blockingWrites.size()
309  << " blocking writes";
310 
311  if (lastwrite == nullptr) {
312  LDBG() << "No last write candidate found, cannot forward";
313  return;
314  }
315 
316  LDBG() << "Last write candidate: " << *lastwrite;
317  Region *topRegion = lastwrite->getParentRegion();
318  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
319  assert(readAncestor &&
320  "read op should be recursively part of the top region");
321  LDBG() << "Read ancestor in top region: " << *readAncestor;
322 
323  LDBG() << "Checking " << blockingWrites.size()
324  << " blocking writes for post-dominance";
325  for (Operation *write : blockingWrites) {
326  LDBG() << "Checking blocking write: " << *write;
327  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
328  if (writeAncestor) {
329  LDBG() << " -> Write ancestor: " << *writeAncestor;
330  } else {
331  LDBG() << " -> Write ancestor: nullptr";
332  }
333 
334  // TODO: if the store and read have the same ancestor we could recurse in
335  // the region to know if the read is reachable with more precision.
336  if (writeAncestor == nullptr) {
337  LDBG() << " -> No ancestor in top region, skipping";
338  continue;
339  }
340 
341  bool isReachableToRead = isReachable(writeAncestor, readAncestor);
342  LDBG() << " -> Is reachable to read: " << isReachableToRead;
343  if (!isReachableToRead) {
344  LDBG() << " -> Not reachable, skipping";
345  continue;
346  }
347 
348  bool lastWritePostDominates =
349  postDominators.postDominates(lastwrite, write);
350  LDBG() << " -> Last write post-dominates blocking write: "
351  << lastWritePostDominates;
352  if (!lastWritePostDominates) {
353  LDBG() << "Fail to do write to read forwarding due to op: " << *write;
354  return;
355  }
356  LDBG() << " -> Blocking write is post-dominated, continuing";
357  }
358 
359  LDBG() << "Forward value from " << *lastwrite.getOperation()
360  << " to: " << *read.getOperation();
361  read.replaceAllUsesWith(lastwrite.getVector());
362  opToErase.push_back(read.getOperation());
363 }
364 
365 /// Converts OpFoldResults to int64_t shape without unit dims.
366 static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
367   SmallVector<int64_t> reducedShape;
368  for (const auto size : mixedSizes) {
369  if (llvm::dyn_cast_if_present<Value>(size)) {
370  reducedShape.push_back(ShapedType::kDynamic);
371  continue;
372  }
373 
374  auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
375  if (value == 1)
376  continue;
377  reducedShape.push_back(value.getSExtValue());
378  }
379  return reducedShape;
380 }
381 
382 /// Drops unit dimensions from the input MemRefType.
383 static MemRefType dropUnitDims(MemRefType inputType,
384  ArrayRef<OpFoldResult> offsets,
385                                ArrayRef<OpFoldResult> sizes,
386                                ArrayRef<OpFoldResult> strides) {
387  auto targetShape = getReducedShape(sizes);
388  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
389  targetShape, inputType, offsets, sizes, strides);
390  return rankReducedType.canonicalizeStridedLayout();
391 }
392 
393 /// Creates a rank-reducing memref.subview op that drops unit dims from its
394 /// input. Or just returns the input if it was already without unit dims.
395 static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
396                                                  mlir::Location loc,
397  Value input) {
398  MemRefType inputType = cast<MemRefType>(input.getType());
399  SmallVector<OpFoldResult> offsets(inputType.getRank(),
400  rewriter.getIndexAttr(0));
401  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
402  SmallVector<OpFoldResult> strides(inputType.getRank(),
403  rewriter.getIndexAttr(1));
404  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
405 
406  if (resultType.canonicalizeStridedLayout() ==
407  inputType.canonicalizeStridedLayout())
408  return input;
409  return memref::SubViewOp::create(rewriter, loc, resultType, input, offsets,
410  sizes, strides);
411 }
412 
413 /// Returns the number of dims that aren't unit dims.
414 static int getReducedRank(ArrayRef<int64_t> shape) {
415   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
416 }
417 
418 /// Trims non-scalable one dimensions from `oldType` and returns the result
419 /// type.
420 static VectorType trimNonScalableUnitDims(VectorType oldType) {
421  SmallVector<int64_t> newShape;
422  SmallVector<bool> newScalableDims;
423  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
424  if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
425  continue;
426  newShape.push_back(dimSize);
427  newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
428  }
429  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
430 }
431 
432 // Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
433 static FailureOr<Value>
434 createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
435                                   vector::CreateMaskOp op) {
436  auto type = op.getType();
437  VectorType reducedType = trimNonScalableUnitDims(type);
438  if (reducedType.getRank() == type.getRank())
439  return failure();
440 
441  SmallVector<Value> reducedOperands;
442  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
443  type.getShape(), type.getScalableDims(), op.getOperands())) {
444  if (dim == 1 && !dimIsScalable) {
445  // If the mask for the unit dim is not a constant of 1, do nothing.
446  auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
447  if (!constant || (constant.value() != 1))
448  return failure();
449  continue;
450  }
451  reducedOperands.push_back(operand);
452  }
453  return vector::CreateMaskOp::create(rewriter, loc, reducedType,
454  reducedOperands)
455  .getResult();
456 }
457 
458 namespace {
459 
460 /// Rewrites `vector.transfer_read` ops where the source has unit dims, by
461 /// inserting a memref.subview dropping those unit dims. The vector shapes are
462 /// also reduced accordingly.
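463 /// Illustrative sketch of the rewrite (hypothetical shapes, layouts elided):
463 ///
463 ///   %v = vector.transfer_read %m[%c0, %c0, %c0], %pad
463 ///       {in_bounds = [true, true, true]}
463 ///       : memref<1x8x1xf32>, vector<1x8x1xf32>
463 ///
463 /// becomes, roughly:
463 ///
463 ///   %sv = memref.subview %m[0, 0, 0] [1, 8, 1] [1, 1, 1]   // rank-reducing
463 ///   %r  = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
463 ///       : memref<8xf32>, vector<8xf32>
463 ///   %v  = vector.shape_cast %r : vector<8xf32> to vector<1x8x1xf32>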
463 class TransferReadDropUnitDimsPattern
464  : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
465  using MaskableOpRewritePattern::MaskableOpRewritePattern;
466 
467  FailureOr<Value>
468  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
469  vector::MaskingOpInterface maskingOp,
470  PatternRewriter &rewriter) const override {
471  LDBG() << "=== TransferReadDropUnitDimsPattern: Analyzing "
472  << *transferReadOp;
473  auto loc = transferReadOp.getLoc();
474  Value vector = transferReadOp.getVector();
475  VectorType vectorType = cast<VectorType>(vector.getType());
476  Value source = transferReadOp.getBase();
477  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
478  // TODO: support tensor types.
479  if (!sourceType) {
480  LDBG() << " -> Not a MemRefType, skipping";
481  return failure();
482  }
483  // TODO: generalize this pattern, relax the requirements here.
484  if (transferReadOp.hasOutOfBoundsDim()) {
485  LDBG() << " -> Has out-of-bounds dimensions, skipping";
486  return failure();
487  }
488  if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
489  LDBG() << " -> Not minor identity permutation map, skipping";
490  return failure();
491  }
492  // Check if the source shape can be further reduced.
493  int reducedRank = getReducedRank(sourceType.getShape());
494  LDBG() << " -> Source rank: " << sourceType.getRank()
495  << ", Reduced rank: " << reducedRank;
496  if (reducedRank == sourceType.getRank()) {
497  LDBG() << " -> No unit dimensions to drop, skipping";
498  return failure();
499  }
500  // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
501  // out.
502  if (reducedRank == 0 && maskingOp) {
503  LDBG() << " -> 0-d vector with masking not supported, skipping";
504  return failure();
505  }
506  // Check if the reduced vector shape matches the reduced source shape.
507  // Otherwise, this case is not supported yet.
508  VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
509  LDBG() << " -> Vector type: " << vectorType
510  << ", Reduced vector type: " << reducedVectorType;
511  if (reducedRank != reducedVectorType.getRank()) {
512  LDBG() << " -> Reduced ranks don't match, skipping";
513  return failure();
514  }
515  if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
516  return getConstantIntValue(v) != static_cast<int64_t>(0);
517  })) {
518  LDBG() << " -> Non-zero indices found, skipping";
519  return failure();
520  }
521 
522  Value maskOp = transferReadOp.getMask();
523  if (maskOp) {
524  LDBG() << " -> Processing mask operation";
525  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
526  if (!createMaskOp) {
527  LDBG()
528  << " -> Unsupported mask op, only 'vector.create_mask' supported";
529  return rewriter.notifyMatchFailure(
530  transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
531  "currently supported");
532  }
533  FailureOr<Value> rankReducedCreateMask =
534  createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
535  if (failed(rankReducedCreateMask)) {
536  LDBG() << " -> Failed to reduce mask dimensions";
537  return failure();
538  }
539  maskOp = *rankReducedCreateMask;
540  LDBG() << " -> Successfully reduced mask dimensions";
541  }
542 
543  LDBG() << " -> Creating rank-reduced subview and new transfer_read";
544  Value reducedShapeSource =
545  rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
546  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
547  SmallVector<Value> zeros(reducedRank, c0);
548  auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
549  SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
550  Operation *newTransferReadOp = vector::TransferReadOp::create(
551  rewriter, loc, reducedVectorType, reducedShapeSource, zeros,
552  identityMap, transferReadOp.getPadding(), maskOp,
553  rewriter.getBoolArrayAttr(inBounds));
554  LDBG() << " -> Created new transfer_read: " << *newTransferReadOp;
555 
556  if (maskingOp) {
557  LDBG() << " -> Applying masking operation";
558  auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
559  loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
560  maskingOp.getMask());
561  newTransferReadOp = mlir::vector::maskOperation(
562  rewriter, newTransferReadOp, shapeCastMask);
563  }
564 
565  auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
566  loc, vectorType, newTransferReadOp->getResults()[0]);
567  LDBG() << " -> Created shape cast: " << *shapeCast.getDefiningOp();
568  LDBG() << " -> Pattern match successful, returning result";
569 
570  return shapeCast;
571  }
572 };
573 
574 /// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
575 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
576 /// vector shapes are also reduced accordingly.
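577 /// Illustrative sketch (hypothetical shapes, layouts elided): a write of
577 /// vector<1x8x1xf32> into memref<1x8x1xf32> becomes a vector.shape_cast to
577 /// vector<8xf32> followed by a transfer_write into a rank-reducing
577 /// memref.subview of the destination.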
577 class TransferWriteDropUnitDimsPattern
578  : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
579  using MaskableOpRewritePattern::MaskableOpRewritePattern;
580 
581  FailureOr<Value>
582  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
583  vector::MaskingOpInterface maskingOp,
584  PatternRewriter &rewriter) const override {
585  LDBG() << "=== TransferWriteDropUnitDimsPattern: Analyzing "
586  << *transferWriteOp;
587  auto loc = transferWriteOp.getLoc();
588  Value vector = transferWriteOp.getVector();
589  VectorType vectorType = cast<VectorType>(vector.getType());
590  Value source = transferWriteOp.getBase();
591  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
592  // TODO: support tensor type.
593  if (!sourceType) {
594  LDBG() << " -> Not a MemRefType, skipping";
595  return failure();
596  }
597  // TODO: generalize this pattern, relax the requirements here.
598  if (transferWriteOp.hasOutOfBoundsDim()) {
599  LDBG() << " -> Has out-of-bounds dimensions, skipping";
600  return failure();
601  }
602  if (!transferWriteOp.getPermutationMap().isMinorIdentity()) {
603  LDBG() << " -> Not minor identity permutation map, skipping";
604  return failure();
605  }
606  // Check if the destination shape can be further reduced.
607  int reducedRank = getReducedRank(sourceType.getShape());
608  LDBG() << " -> Source rank: " << sourceType.getRank()
609  << ", Reduced rank: " << reducedRank;
610  if (reducedRank == sourceType.getRank()) {
611  LDBG() << " -> No unit dimensions to drop, skipping";
612  return failure();
613  }
614  // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
615  // out.
616  if (reducedRank == 0 && maskingOp) {
617  LDBG() << " -> 0-d vector with masking not supported, skipping";
618  return failure();
619  }
620  // Check if the reduced vector shape matches the reduced destination shape.
621  // Otherwise, this case is not supported yet.
622  VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
623  LDBG() << " -> Vector type: " << vectorType
624  << ", Reduced vector type: " << reducedVectorType;
625  if (reducedRank != reducedVectorType.getRank()) {
626  LDBG() << " -> Reduced ranks don't match, skipping";
627  return failure();
628  }
629  if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
630  return getConstantIntValue(v) != static_cast<int64_t>(0);
631  })) {
632  LDBG() << " -> Non-zero indices found, skipping";
633  return failure();
634  }
635 
636  Value maskOp = transferWriteOp.getMask();
637  if (maskOp) {
638  LDBG() << " -> Processing mask operation";
639  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
640  if (!createMaskOp) {
641  LDBG()
642  << " -> Unsupported mask op, only 'vector.create_mask' supported";
643  return rewriter.notifyMatchFailure(
644  transferWriteOp,
645  "unsupported mask op, only 'vector.create_mask' is "
646  "currently supported");
647  }
648  FailureOr<Value> rankReducedCreateMask =
649  createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
650  if (failed(rankReducedCreateMask)) {
651  LDBG() << " -> Failed to reduce mask dimensions";
652  return failure();
653  }
654  maskOp = *rankReducedCreateMask;
655  LDBG() << " -> Successfully reduced mask dimensions";
656  }
657  LDBG() << " -> Creating rank-reduced subview and new transfer_write";
658  Value reducedShapeSource =
659  rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
660  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
661  SmallVector<Value> zeros(reducedRank, c0);
662  auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
663  SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
664  auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
665  loc, reducedVectorType, vector);
666  Operation *newXferWrite = vector::TransferWriteOp::create(
667  rewriter, loc, Type(), shapeCastSrc, reducedShapeSource, zeros,
668  identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
669  LDBG() << " -> Created new transfer_write: " << *newXferWrite;
670 
671  if (maskingOp) {
672  LDBG() << " -> Applying masking operation";
673  auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
674  loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
675  maskingOp.getMask());
676  newXferWrite =
677  mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
678  }
679 
680  if (transferWriteOp.hasPureTensorSemantics()) {
681  LDBG() << " -> Pattern match successful (tensor semantics), returning "
682  "result";
683  return newXferWrite->getResults()[0];
684  }
685 
686  // With Memref semantics, there's no return value. Use empty value to signal
687  // success.
688  LDBG() << " -> Pattern match successful (memref semantics)";
689  return Value();
690  }
691 };
692 
693 } // namespace
694 
695 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
696 /// input starting at `firstDimToCollapse`.
697 static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
698                                Value input, int64_t firstDimToCollapse) {
699  ShapedType inputType = cast<ShapedType>(input.getType());
700  if (inputType.getRank() == 1)
701  return input;
702  SmallVector<ReassociationIndices> reassociation;
703  for (int64_t i = 0; i < firstDimToCollapse; ++i)
704  reassociation.push_back(ReassociationIndices{i});
705  ReassociationIndices collapsedIndices;
706  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
707  collapsedIndices.push_back(i);
708  reassociation.push_back(collapsedIndices);
709  return memref::CollapseShapeOp::create(rewriter, loc, input, reassociation);
710 }
711 
712 /// Returns the new indices that collapse the inner dimensions starting from
713 /// the `firstDimToCollapse` dimension.
714 static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
715                                               Location loc,
716  ArrayRef<int64_t> shape,
717  ValueRange indices,
718  int64_t firstDimToCollapse) {
719  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
720 
721  // If all the collapsed indices are zero then no extra logic is needed.
722  // Otherwise, a new offset/index has to be computed.
723  SmallVector<Value> indicesAfterCollapsing(
724  indices.begin(), indices.begin() + firstDimToCollapse);
725  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
726  indices.end());
727  if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
728  indicesAfterCollapsing.push_back(indicesToCollapse[0]);
729  return indicesAfterCollapsing;
730  }
731 
732  // Compute the remaining trailing index/offset required for reading from
733  // the collapsed memref:
734  //
735  // offset = 0
736  // for (i = firstDimToCollapse; i < outputRank; ++i)
737  // offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
738  //
739  // For this example:
740  // %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
741  // memref<1x43x2xi32>, vector<1x2xi32>
742  // which would be collapsed to:
743  // %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
744  // memref<1x86xi32>, vector<2xi32>
745  // one would get the following offset:
746  // %offset = %arg0 * 43
747  OpFoldResult collapsedOffset =
748  arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
749 
750  auto collapsedStrides = computeSuffixProduct(
751  ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
752 
753  // Compute the collapsed offset.
754  auto &&[collapsedExpr, collapsedVals] =
755  computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
756  collapsedOffset = affine::makeComposedFoldedAffineApply(
757  rewriter, loc, collapsedExpr, collapsedVals);
758 
759  if (auto value = dyn_cast<Value>(collapsedOffset)) {
760  indicesAfterCollapsing.push_back(value);
761  } else {
762  indicesAfterCollapsing.push_back(arith::ConstantIndexOp::create(
763  rewriter, loc, *getConstantIntValue(collapsedOffset)));
764  }
765 
766  return indicesAfterCollapsing;
767 }
768 
769 namespace {
770 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
771 /// memref.collapse_shape on the source so that the resulting
772 /// vector.transfer_read has a 1D source. Requires the source shape to be
773 /// already reduced i.e. without unit dims.
774 ///
775 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
776 /// the trailing dimension of the vector read is smaller than the provided
777 /// bitwidth.
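778 /// Illustrative sketch (hypothetical shapes, not taken from this file's tests):
778 ///
778 ///   %v = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
778 ///       : memref<4x8xf32>, vector<4x8xf32>
778 ///
778 /// becomes, roughly:
778 ///
778 ///   %c = memref.collapse_shape %m [[0, 1]] : memref<4x8xf32> into memref<32xf32>
778 ///   %r = vector.transfer_read %c[%c0], %pad {in_bounds = [true]}
778 ///       : memref<32xf32>, vector<32xf32>
778 ///   %v = vector.shape_cast %r : vector<32xf32> to vector<4x8xf32>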
778 class FlattenContiguousRowMajorTransferReadPattern
779  : public OpRewritePattern<vector::TransferReadOp> {
780 public:
781  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
782  unsigned vectorBitwidth,
783  PatternBenefit benefit)
784  : OpRewritePattern<vector::TransferReadOp>(context, benefit),
785  targetVectorBitwidth(vectorBitwidth) {}
786 
787  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
788  PatternRewriter &rewriter) const override {
789  LDBG() << "=== FlattenContiguousRowMajorTransferReadPattern: Analyzing "
790  << *transferReadOp;
791  auto loc = transferReadOp.getLoc();
792  Value vector = transferReadOp.getVector();
793  VectorType vectorType = cast<VectorType>(vector.getType());
794  auto source = transferReadOp.getBase();
795  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
796 
797  // 0. Check pre-conditions
798   // The contiguity check below only applies to memrefs; bail out on tensors.
799  if (!sourceType) {
800  LDBG() << " -> Not a MemRefType, skipping";
801  return failure();
802  }
803  // If this is already 0D/1D, there's nothing to do.
804  if (vectorType.getRank() <= 1) {
805  LDBG() << " -> Already 0D/1D, skipping";
806  return failure();
807  }
808  if (!vectorType.getElementType().isSignlessIntOrFloat()) {
809  LDBG() << " -> Not signless int or float, skipping";
810  return failure();
811  }
812  unsigned trailingVectorDimBitwidth =
813  vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
814  LDBG() << " -> Trailing vector dim bitwidth: " << trailingVectorDimBitwidth
815  << ", target: " << targetVectorBitwidth;
816  if (trailingVectorDimBitwidth >= targetVectorBitwidth) {
817  LDBG() << " -> Trailing dim bitwidth >= target, skipping";
818  return failure();
819  }
820  if (!vector::isContiguousSlice(sourceType, vectorType)) {
821  LDBG() << " -> Not contiguous slice, skipping";
822  return failure();
823  }
824  // TODO: generalize this pattern, relax the requirements here.
825  if (transferReadOp.hasOutOfBoundsDim()) {
826  LDBG() << " -> Has out-of-bounds dimensions, skipping";
827  return failure();
828  }
829  if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
830  LDBG() << " -> Not minor identity permutation map, skipping";
831  return failure();
832  }
833  if (transferReadOp.getMask()) {
834  LDBG() << " -> Has mask, skipping";
835  return failure();
836  }
837 
838  // Determine the first memref dimension to collapse - just enough so we can
839  // read a flattened vector.
840  int64_t firstDimToCollapse =
841  sourceType.getRank() -
842  vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
843  LDBG() << " -> First dimension to collapse: " << firstDimToCollapse;
844 
845  // 1. Collapse the source memref
846  LDBG() << " -> Collapsing source memref";
847  Value collapsedSource =
848  collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
849  MemRefType collapsedSourceType =
850  cast<MemRefType>(collapsedSource.getType());
851  int64_t collapsedRank = collapsedSourceType.getRank();
852  assert(collapsedRank == firstDimToCollapse + 1);
853  LDBG() << " -> Collapsed source type: " << collapsedSourceType;
854 
855  // 2. Generate input args for a new vector.transfer_read that will read
856  // from the collapsed memref.
857  // 2.1. New dim exprs + affine map
858     SmallVector<AffineExpr, 1> dimExprs{
859         getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
860  auto collapsedMap =
861  AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
862 
863  // 2.2 New indices
864  SmallVector<Value> collapsedIndices =
865  getCollapsedIndices(rewriter, loc, sourceType.getShape(),
866  transferReadOp.getIndices(), firstDimToCollapse);
867 
868  // 3. Create new vector.transfer_read that reads from the collapsed memref
869  VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
870  vectorType.getElementType());
871  LDBG() << " -> Creating flattened vector type: " << flatVectorType;
872  vector::TransferReadOp flatRead = vector::TransferReadOp::create(
873  rewriter, loc, flatVectorType, collapsedSource, collapsedIndices,
874  transferReadOp.getPadding(), collapsedMap);
875  flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
876  LDBG() << " -> Created flat transfer_read: " << *flatRead;
877 
878  // 4. Replace the old transfer_read with the new one reading from the
879  // collapsed shape
880  LDBG() << " -> Replacing with shape cast";
881  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
882  transferReadOp, cast<VectorType>(vector.getType()), flatRead);
883  LDBG() << " -> Pattern match successful";
884  return success();
885  }
886 
887 private:
888  // Minimum bitwidth that the trailing vector dimension should have after
889  // flattening.
890  unsigned targetVectorBitwidth;
891 };
892 
893 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
894 /// memref.collapse_shape on the source so that the resulting
895 /// vector.transfer_write has a 1D source. Requires the source shape to be
896 /// already reduced i.e. without unit dims.
897 ///
898 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
899 /// the trailing dimension of the written vector is smaller than the provided
900 /// bitwidth.
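901 /// Illustrative sketch (hypothetical shapes): a write of vector<4x8xf32> into
901 /// memref<4x8xf32> becomes a vector.shape_cast to vector<32xf32> followed by
901 /// a transfer_write into the memref.collapse_shape of the destination.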
901 class FlattenContiguousRowMajorTransferWritePattern
902  : public OpRewritePattern<vector::TransferWriteOp> {
903 public:
904  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
905  unsigned vectorBitwidth,
906  PatternBenefit benefit)
907  : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
908  targetVectorBitwidth(vectorBitwidth) {}
909 
910  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
911  PatternRewriter &rewriter) const override {
912  auto loc = transferWriteOp.getLoc();
913  Value vector = transferWriteOp.getVector();
914  VectorType vectorType = cast<VectorType>(vector.getType());
915  Value source = transferWriteOp.getBase();
916  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
917 
918  // 0. Check pre-conditions
919   // The contiguity check below only applies to memrefs; bail out on tensors.
920  if (!sourceType)
921  return failure();
922  // If this is already 0D/1D, there's nothing to do.
923  if (vectorType.getRank() <= 1)
925  return failure();
926  if (!vectorType.getElementType().isSignlessIntOrFloat())
927  return failure();
928  unsigned trailingVectorDimBitwidth =
929  vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
930  if (trailingVectorDimBitwidth >= targetVectorBitwidth)
931  return failure();
932  if (!vector::isContiguousSlice(sourceType, vectorType))
933  return failure();
934  // TODO: generalize this pattern, relax the requirements here.
935  if (transferWriteOp.hasOutOfBoundsDim())
936  return failure();
937  if (!transferWriteOp.getPermutationMap().isMinorIdentity())
938  return failure();
939  if (transferWriteOp.getMask())
940  return failure();
941 
942  // Determine the first memref dimension to collapse - just enough so we can
943  // read a flattened vector.
944  int64_t firstDimToCollapse =
945  sourceType.getRank() -
946  vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
947 
948  // 1. Collapse the source memref
949  Value collapsedSource =
950  collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
951  MemRefType collapsedSourceType =
952  cast<MemRefType>(collapsedSource.getType());
953  int64_t collapsedRank = collapsedSourceType.getRank();
954  assert(collapsedRank == firstDimToCollapse + 1);
955 
956   // 2. Generate input args for a new vector.transfer_write that will write
957   // to the collapsed memref.
958  // 2.1. New dim exprs + affine map
959     SmallVector<AffineExpr, 1> dimExprs{
960         getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
961  auto collapsedMap =
962  AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
963 
964  // 2.2 New indices
965  SmallVector<Value> collapsedIndices =
966  getCollapsedIndices(rewriter, loc, sourceType.getShape(),
967  transferWriteOp.getIndices(), firstDimToCollapse);
968 
969  // 3. Create new vector.transfer_write that writes to the collapsed memref
970  VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
971  vectorType.getElementType());
972  Value flatVector =
973  vector::ShapeCastOp::create(rewriter, loc, flatVectorType, vector);
974  vector::TransferWriteOp flatWrite = vector::TransferWriteOp::create(
975  rewriter, loc, flatVector, collapsedSource, collapsedIndices,
976  collapsedMap);
977  flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
978 
979  // 4. Replace the old transfer_write with the new one writing the
980  // collapsed shape
981  rewriter.eraseOp(transferWriteOp);
982  return success();
983  }
984 
985 private:
986  // Minimum bitwidth that the trailing vector dimension should have after
987  // flattening.
988  unsigned targetVectorBitwidth;
989 };
990 
991 /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
992 ///
993 /// All the users of the transfer op must be `vector.extract` ops. If
994 /// `allowMultipleUses` is set to true, rewrite transfer ops with any number of
995 /// users. Otherwise, rewrite only if the extract op is the single user of the
996 /// transfer op. Rewriting a single vector load with multiple scalar loads may
997 /// negatively affect performance.
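998 /// Illustrative sketch (hypothetical IR):
998 ///
998 ///   %v = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
998 ///       : memref<?xf32>, vector<4xf32>
998 ///   %e = vector.extract %v[2] : f32 from vector<4xf32>
998 ///
998 /// becomes, roughly:
998 ///
998 ///   %idx = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%i]
998 ///   %e = memref.load %m[%idx] : memref<?xf32>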
998 class RewriteScalarExtractOfTransferRead
999  : public OpRewritePattern<vector::ExtractOp> {
1000 public:
1001  RewriteScalarExtractOfTransferRead(MLIRContext *context,
1002  PatternBenefit benefit,
1003  bool allowMultipleUses)
1004  : OpRewritePattern(context, benefit),
1005  allowMultipleUses(allowMultipleUses) {}
1006 
1007  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
1008  PatternRewriter &rewriter) const override {
1009  // Match phase.
1010  auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
1011  if (!xferOp)
1012  return failure();
1013  // Check that we are extracting a scalar and not a sub-vector.
1014  if (isa<VectorType>(extractOp.getResult().getType()))
1015  return failure();
1016  // If multiple uses are not allowed, check if xfer has a single use.
1017  if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
1018  return failure();
1019  // If multiple uses are allowed, check if all the xfer uses are extract ops.
1020  if (allowMultipleUses &&
1021  !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
1022  return isa<vector::ExtractOp>(use.getOwner());
1023  }))
1024  return failure();
1025  // Mask not supported.
1026  if (xferOp.getMask())
1027  return failure();
1028  // Map not supported.
1029  if (!xferOp.getPermutationMap().isMinorIdentity())
1030  return failure();
1031  // Cannot rewrite if the indices may be out of bounds.
1032  if (xferOp.hasOutOfBoundsDim())
1033  return failure();
1034 
1035  // Rewrite phase: construct scalar load.
1036  SmallVector<Value> newIndices(xferOp.getIndices().begin(),
1037  xferOp.getIndices().end());
1038  for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
1039  int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
1040 
1041  // Compute affine expression `newIndices[idx] + pos` where `pos` can be
1042  // either a constant or a value.
1043  OpFoldResult composedIdx;
1044  if (auto attr = dyn_cast<Attribute>(pos)) {
1045  int64_t offset = cast<IntegerAttr>(attr).getInt();
1046         composedIdx = affine::makeComposedFoldedAffineApply(
1047             rewriter, extractOp.getLoc(),
1048  rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
1049  } else {
1050  Value dynamicOffset = cast<Value>(pos);
1051  AffineExpr sym0, sym1;
1052  bindSymbols(rewriter.getContext(), sym0, sym1);
1053         composedIdx = affine::makeComposedFoldedAffineApply(
1054             rewriter, extractOp.getLoc(), sym0 + sym1,
1055  {newIndices[idx], dynamicOffset});
1056  }
1057 
1058  // Update the corresponding index with the folded result.
1059  if (auto value = dyn_cast<Value>(composedIdx)) {
1060  newIndices[idx] = value;
1061  } else {
1062  newIndices[idx] = arith::ConstantIndexOp::create(
1063  rewriter, extractOp.getLoc(), *getConstantIntValue(composedIdx));
1064  }
1065  }
1066  if (isa<MemRefType>(xferOp.getBase().getType())) {
1067  rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
1068  newIndices);
1069  } else {
1070  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1071  extractOp, xferOp.getBase(), newIndices);
1072  }
1073 
1074  return success();
1075  }
1076 
1077 private:
1078  bool allowMultipleUses;
1079 };
1080 
1081 /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
1082 /// to memref.store.
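1083 /// Illustrative sketch (hypothetical IR):
1083 ///
1083 ///   vector.transfer_write %v, %m[%c0, %c1] {in_bounds = [true, true]}
1083 ///       : vector<1x1xf32>, memref<4x4xf32>
1083 ///
1083 /// becomes, roughly:
1083 ///
1083 ///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
1083 ///   memref.store %s, %m[%c0, %c1] : memref<4x4xf32>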
1083 class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
1084   using OpRewritePattern::OpRewritePattern;
1085 
1086  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
1087  PatternRewriter &rewriter) const override {
1088  // Must be a scalar write.
1089  auto vecType = xferOp.getVectorType();
1090  if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
1091  return failure();
1092  // Mask not supported.
1093  if (xferOp.getMask())
1094  return failure();
1095  // Map not supported.
1096  if (!xferOp.getPermutationMap().isMinorIdentity())
1097  return failure();
1098  // Only float and integer element types are supported.
1099  Value scalar = vector::ExtractOp::create(rewriter, xferOp.getLoc(),
1100  xferOp.getVector());
1101  // Construct a scalar store.
1102  if (isa<MemRefType>(xferOp.getBase().getType())) {
1103  rewriter.replaceOpWithNewOp<memref::StoreOp>(
1104  xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
1105  } else {
1106  rewriter.replaceOpWithNewOp<tensor::InsertOp>(
1107  xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
1108  }
1109  return success();
1110  }
1111 };
1112 
1113 } // namespace
1114 
1115 void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
1116                                      Operation *rootOp) {
1117  LDBG() << "=== Starting transferOpflowOpt on root operation: "
1118  << OpWithFlags(rootOp, OpPrintingFlags().skipRegions());
1119  TransferOptimization opt(rewriter, rootOp);
1120 
1121   // Run store-to-load forwarding first since it can expose more dead store
1122   // opportunities.
1123  LDBG() << "Phase 1: Store-to-load forwarding";
1124  int readCount = 0;
1125  rootOp->walk([&](vector::TransferReadOp read) {
1126  if (isa<MemRefType>(read.getShapedType())) {
1127  LDBG() << "Processing transfer_read #" << ++readCount << ": " << *read;
1128  opt.storeToLoadForwarding(read);
1129  }
1130  });
1131  LDBG() << "Phase 1 complete. Removing dead operations from forwarding";
1132  opt.removeDeadOp();
1133 
1134  LDBG() << "Phase 2: Dead store elimination";
1135  int writeCount = 0;
1136  rootOp->walk([&](vector::TransferWriteOp write) {
1137  if (isa<MemRefType>(write.getShapedType())) {
1138  LDBG() << "Processing transfer_write #" << ++writeCount << ": " << *write;
1139  opt.deadStoreOp(write);
1140  }
1141  });
1142  LDBG() << "Phase 2 complete. Removing dead operations from dead store "
1143  "elimination";
1144  opt.removeDeadOp();
1145  LDBG() << "=== transferOpflowOpt complete";
1146 }
1147 
1148 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
1149     RewritePatternSet &patterns, PatternBenefit benefit,
1150     bool allowMultipleUses) {
1151  patterns.add<RewriteScalarExtractOfTransferRead>(patterns.getContext(),
1152  benefit, allowMultipleUses);
1153  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
1154 }
1155 
1156 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
1157     RewritePatternSet &patterns, PatternBenefit benefit) {
1158   patterns
1159  .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
1160  patterns.getContext(), benefit);
1161 }
1162 
1163 void mlir::vector::populateFlattenVectorTransferPatterns(
1164  RewritePatternSet &patterns, unsigned targetVectorBitwidth,
1165  PatternBenefit benefit) {
1166  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
1167  FlattenContiguousRowMajorTransferWritePattern>(
1168  patterns.getContext(), targetVectorBitwidth, benefit);
1169  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
1170 }
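// Usage sketch (illustrative only, not part of the upstream file): the
// populate* entry points above are meant to fill a RewritePatternSet that is
// then handed to a pattern driver, e.g. assuming an `MLIRContext *ctx` and a
// root `Operation *rootOp`:
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns, /*benefit=*/1);
//   vector::populateFlattenVectorTransferPatterns(
//       patterns, /*targetVectorBitwidth=*/128, /*benefit=*/1);
//   (void)applyPatternsGreedily(rootOp, std::move(patterns));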