//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(Operation *op) : dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      op->erase();
    opToErase.clear();
  }

private:
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  llvm::SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, the original transfer_write
/// is dead if all reads that can be reached from the potentially dead
/// transfer_write are dominated by the overwriting transfer_write.
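///
/// For example (an illustrative sketch with made-up SSA names, not IR taken
/// from this file's tests), the first transfer_write below is dead because the
/// second one fully overwrites the same elements before any read observes
/// them:
///
///   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>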
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that can overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation())))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
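///
/// For example (an illustrative sketch with made-up SSA names, not IR taken
/// from this file's tests), the transfer_read below can directly reuse %v
/// because the dominating transfer_write stores exactly the elements it loads:
///
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
///
/// After forwarding, all uses of %r are replaced by %v and the read is erased.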
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Failed to do write-to-read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                               ArrayRef<int64_t> sizes,
                               ArrayRef<int64_t> strides) {
  SmallVector<int64_t> targetShape = llvm::to_vector(
      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = input.getType().cast<MemRefType>();
  assert(inputType.hasStaticShape());
  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
  ArrayRef<int64_t> subViewSizes = inputType.getShape();
  MemRefType resultType =
      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(
      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Returns true if the value is an `arith.constant 0 : index`.
static bool isZero(Value v) {
  auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
  return cst && cst.value() == 0;
}

/// Rewrites vector.transfer_read ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims.
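///
/// For example (an illustrative sketch with made-up SSA names, not IR taken
/// from this file's tests):
///
///   %r = vector.transfer_read %m[%c0, %c0, %c0], %pad {in_bounds = [true]}
///       : memref<1x1x4xf32>, vector<4xf32>
///
/// becomes a read through a rank-reducing subview:
///
///   %sv = memref.subview %m[0, 0, 0] [1, 1, 4] [1, 1, 1]
///       : memref<1x1x4xf32> to memref<4xf32>
///   %r = vector.transfer_read %sv[%c0], %pad
///       : memref<4xf32>, vector<4xf32>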
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferReadOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
/// unit dims, by inserting a memref.subview dropping those unit dims.
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferWriteOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Return true if the memref type has its inner dimensions matching the given
/// shape and those dimensions are contiguous in memory. Otherwise return
/// false.
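///
/// For example (illustrative values, not taken from this file's tests), a
/// memref<5x4x3x2xi8> with strides [24, 6, 2, 1] matches the target shape
/// [4, 3, 2]: the innermost stride is 1, the inner dimensions equal the target
/// shape, and each checked stride equals the product of the dimensions inside
/// it (2, 2*3 = 6, 2*3*4 = 24).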
static bool hasMatchingInnerContigousShape(MemRefType memrefType,
                                           ArrayRef<int64_t> targetShape) {
  auto shape = memrefType.getShape();
  SmallVector<int64_t> strides;
  int64_t offset;
  if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
    return false;
  if (strides.back() != 1)
    return false;
  strides.pop_back();
  int64_t flatDim = 1;
  for (auto [targetDim, memrefDim, memrefStride] :
       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
    flatDim *= memrefDim;
    if (flatDim != memrefStride || targetDim != memrefDim)
      return false;
  }
  return true;
}

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
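///
/// For example (an illustrative sketch), with `firstDimToCollapse = 1` a
/// memref<5x4x3x2xi8> is collapsed with reassociation [[0], [1, 2, 3]] into
/// memref<5x24xi8>.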
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = input.getType().cast<ShapedType>();
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
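///
/// For example (an illustrative sketch), indices (%i, %c0, %c0, %c0) with
/// `firstDimToCollapse = 1` become (%i, %c0); if any of the trailing indices
/// is not a constant 0 the check fails.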
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    arith::ConstantIndexOp cst =
        indices[i].getDefiningOp<arith::ConstantIndexOp>();
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
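///
/// For example (an illustrative sketch with made-up SSA names, not IR taken
/// from this file's tests):
///
///   %0 = vector.transfer_read %m[%i, %c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true]}
///       : memref<5x4x3x2xi8>, vector<4x3x2xi8>
///
/// becomes
///
///   %c = memref.collapse_shape %m [[0], [1, 2, 3]]
///       : memref<5x4x3x2xi8> into memref<5x24xi8>
///   %r = vector.transfer_read %c[%i, %c0], %pad {in_bounds = [true]}
///       : memref<5x24xi8>, vector<24xi8>
///   %0 = vector.shape_cast %r : vector<24xi8> to vector<4x3x2xi8>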
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below requires a strided memref; bail out on
    // tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContigousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        collapsedSource.getType().dyn_cast<MemRefType>();
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, vector.getType().cast<VectorType>(), flatRead);
    return success();
  }
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below requires a strided memref; bail out on
    // tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContigousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        collapsedSource.getType().cast<MemRefType>();
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(Operation *rootOp) {
  TransferOptimization opt(rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (read.getShapedType().isa<MemRefType>())
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (write.getShapedType().isa<MemRefType>())
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}
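
// A minimal sketch of how these entry points are typically driven (the pass
// below is hypothetical and not part of this file; it assumes the usual Pass,
// Func dialect, and greedy pattern rewrite driver headers are available):
//
//   struct TestVectorTransferOptPass
//       : public PassWrapper<TestVectorTransferOptPass,
//                            OperationPass<func::FuncOp>> {
//     void runOnOperation() override {
//       RewritePatternSet patterns(&getContext());
//       vector::populateFlattenVectorTransferPatterns(patterns);
//       (void)applyPatternsAndFoldGreedily(getOperation(),
//                                          std::move(patterns));
//       // Store-to-load forwarding and dead store elimination run directly on
//       // the root operation rather than as rewrite patterns.
//       vector::transferOpflowOpt(getOperation());
//     }
//   };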