//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

/// Return true if there is a path from start operation to dest operation,
/// otherwise return false. The operations have to be in the same region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  llvm::SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// To fully overwrite another transfer_write, a transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this is the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, the original transfer_write
/// is dead if all reads reachable from the potentially dead transfer_write are
/// dominated by the overwriting transfer_write.
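///
/// Illustrative sketch (made-up values %buf, %v0, %v1; not taken from a test):
/// the first write below is dead because the second write overwrites the same
/// elements before any read can observe them.
///
///   vector.transfer_write %v0, %buf[%c0, %c0]
///       : vector<4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %buf[%c0, %c0]
///       : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %cst
///       : memref<4x4xf32>, vector<4xf32>
///
/// Only the second write and the read remain after the rewrite.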
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // Skip users that have already been processed.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check if this candidate can overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation())))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this is the last
/// transfer_write executed before the transfer_read.
/// If we find such a candidate, we can forward the stored value if all the
/// other potentially aliasing ops that may reach the transfer_read are
/// post-dominated by the transfer_write.
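///
/// Illustrative sketch (made-up values %buf and %v; not taken from a test):
/// the transfer_read below reads back exactly what the dominating
/// transfer_write stored, so its uses can be replaced by %v directly.
///
///   vector.transfer_write %v, %buf[%c0, %c0]
///       : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %cst
///       : memref<4x4xf32>, vector<4xf32>
///
/// All uses of %r are replaced with %v and the read is then erased.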
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // Skip users that have already been processed.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                               ArrayRef<int64_t> sizes,
                               ArrayRef<int64_t> strides) {
  SmallVector<int64_t> targetShape = llvm::to_vector(
      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it has no unit dims.
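///
/// Illustrative sketch (made-up types, not from a test):
///
///   %0 = memref.subview %arg[0, 0, 0] [1, 8, 1] [1, 1, 1]
///       : memref<1x8x1xf32> to memref<8xf32>
///
/// i.e. the unit dims of the source are dropped from the result type.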
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = input.getType().cast<MemRefType>();
  assert(inputType.hasStaticShape());
  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
  ArrayRef<int64_t> subViewSizes = inputType.getShape();
  MemRefType resultType =
      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(
      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Returns true if the value is an `arith.constant 0 : index`.
static bool isZero(Value v) {
  auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
  return cst && cst.value() == 0;
}

/// Rewrites vector.transfer_read ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims.
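///
/// Illustrative sketch (made-up values and types, not from a test):
///
///   %r = vector.transfer_read %arg[%c0, %c0, %c0], %cst {in_bounds = [true]}
///       : memref<1x1x8xf32>, vector<8xf32>
///
/// is rewritten to roughly:
///
///   %sv = memref.subview %arg[0, 0, 0] [1, 1, 8] [1, 1, 1]
///       : memref<1x1x8xf32> to memref<8xf32>
///   %r = vector.transfer_read %sv[%c0], %cst
///       : memref<8xf32>, vector<8xf32>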
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferReadOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Rewrites vector.transfer_write ops where the "source" (i.e. destination)
/// has unit dims, by inserting a memref.subview dropping those unit dims.
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor type.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferWriteOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Return true if the trailing dimensions of the memref type match the given
/// target shape and are contiguous (row-major) in memory. Otherwise return
/// false.
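///
/// Illustrative sketch (made-up types): for memref<2x3x4xf32> (row-major
/// strides [12, 4, 1]) and target shape [3, 4], the innermost stride is 1 and
/// each remaining stride equals the product of the sizes of the dimensions
/// inner to it (4 and 3 * 4 = 12), so the check succeeds. With strides
/// [24, 8, 1] (padded rows) it would fail.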
static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
                                              ArrayRef<int64_t> targetShape) {
  auto shape = memrefType.getShape();
  SmallVector<int64_t> strides;
  int64_t offset;
  if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
    return false;
  if (strides.back() != 1)
    return false;
  strides.pop_back();
  int64_t flatDim = 1;
  for (auto [targetDim, memrefDim, memrefStride] :
       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
    flatDim *= memrefDim;
    if (flatDim != memrefStride || targetDim != memrefDim)
      return false;
  }
  return true;
}

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
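///
/// Illustrative sketch (made-up shape): for a memref<2x3x4x5xf32> input and
/// `firstDimToCollapse` = 1, the reassociation is [[0], [1, 2, 3]] and the
/// result type is memref<2x60xf32>.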
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = input.getType().cast<ShapedType>();
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
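///
/// Illustrative sketch (made-up values): for indices (%i, %j, %c0, %c0) and
/// `firstDimToCollapse` = 2, the trailing indices are checked to be constant
/// zero and `outIndices` becomes (%i, %j, %c0).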
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    arith::ConstantIndexOp cst =
        indices[i].getDefiningOp<arith::ConstantIndexOp>();
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read reads a 1-D vector from the collapsed source.
/// Requires the source shape to be already reduced, i.e. without unit dims.
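///
/// Illustrative sketch (made-up values and types):
///
///   %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst
///       {in_bounds = [true, true, true, true]}
///       : memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
///
/// becomes roughly:
///
///   %collapsed = memref.collapse_shape %arg [[0, 1, 2, 3]]
///       : memref<5x4x3x2xi8> into memref<120xi8>
///   %flat = vector.transfer_read %collapsed[%c0], %cst {in_bounds = [true]}
///       : memref<120xi8>, vector<120xi8>
///   %v = vector.shape_cast %flat : vector<120xi8> to vector<5x4x3x2xi8>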
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below only applies to memrefs; tensor sources are
    // not supported.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContigousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        collapsedSource.getType().dyn_cast<MemRefType>();
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, vector.getType().cast<VectorType>(), flatRead);
    return success();
  }
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write writes a 1-D vector to the collapsed source.
/// Requires the source shape to be already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below only applies to memrefs; tensor sources are
    // not supported.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContigousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        collapsedSource.getType().cast<MemRefType>();
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

/// Rewrite extractelement(transfer_read) to memref.load.
///
/// Rewrite only if the extractelement op is the single user of the transfer
/// op. E.g., do not rewrite IR such as:
/// %0 = vector.transfer_read ... : vector<1024xf32>
/// %1 = vector.extractelement %0[%a : index] : vector<1024xf32>
/// %2 = vector.extractelement %0[%b : index] : vector<1024xf32>
/// Rewriting such IR (replacing one vector load with multiple scalar loads)
/// may negatively affect performance.
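///
/// Illustrative sketch of the intended rewrite (made-up values):
///
///   %0 = vector.transfer_read %mem[%i], %cst {in_bounds = [true]}
///       : memref<?xf32>, vector<16xf32>
///   %1 = vector.extractelement %0[%j : index] : vector<16xf32>
///
/// becomes roughly:
///
///   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
///   %1 = memref.load %mem[%idx] : memref<?xf32>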
class RewriteScalarExtractElementOfTransferRead
    : public OpRewritePattern<vector::ExtractElementOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // xfer result must have a single use. Otherwise, it may be better to
    // perform a vector load.
    if (!extractOp.getVector().hasOneUse())
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds. The starting point
    // is always in bounds, so we don't care in case of 0d transfers.
    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
      return failure();
    // Construct scalar load.
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(extractOp.getLoc(),
                                                    *getConstantIntValue(ofr));
      }
    }
    if (xferOp.getSource().getType().isa<MemRefType>()) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
    return success();
  }
};

/// Rewrite extract(transfer_read) to memref.load.
///
/// Rewrite only if the extract op is the single user of the transfer op.
/// E.g., do not rewrite IR such as:
/// %0 = vector.transfer_read ... : vector<1024xf32>
/// %1 = vector.extract %0[0] : vector<1024xf32>
/// %2 = vector.extract %0[5] : vector<1024xf32>
/// Rewriting such IR (replacing one vector load with multiple scalar loads)
/// may negatively affect performance.
class RewriteScalarExtractOfTransferRead
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only match scalar extracts.
    if (extractOp.getType().isa<VectorType>())
      return failure();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // xfer result must have a single use. Otherwise, it may be better to
    // perform a vector load.
    if (!extractOp.getVector().hasOneUse())
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds. The starting point
    // is always in bounds, so we don't care in case of 0d transfers.
    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
      return failure();
    // Construct scalar load.
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
      int64_t offset = it.value().cast<IntegerAttr>().getInt();
      int64_t idx =
          newIndices.size() - extractOp.getPosition().size() + it.index();
      OpFoldResult ofr = makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (xferOp.getSource().getType().isa<MemRefType>()) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
    return success();
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
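///
/// Illustrative sketch (made-up values):
///
///   vector.transfer_write %v, %mem[%i, %j]
///       : vector<1x1xf32>, memref<?x?xf32>
///
/// becomes roughly:
///
///   %s = vector.extract %v[0, 0] : vector<1x1xf32>
///   memref.store %s, %mem[%i, %j] : memref<?x?xf32>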
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (xferOp.getSource().getType().isa<MemRefType>()) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};
} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (read.getShapedType().isa<MemRefType>())
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (write.getShapedType().isa<MemRefType>())
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead, RewriteScalarWrite>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}