VectorTransferOpTransforms.cpp
1 //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with optimizing transfer_read and
10 // transfer_write ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/Dominance.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/Debug.h"
28 
29 #define DEBUG_TYPE "vector-transfer-opt"
30 
31 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
32 
33 using namespace mlir;
34 
35 /// Return the ancestor op in the region or nullptr if the region is not
36 /// an ancestor of the op.
37 static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
38  for (; op != nullptr && op->getParentRegion() != region;
39  op = op->getParentOp())
40  ;
41  return op;
42 }
43 
44 namespace {
45 
46 class TransferOptimization {
47 public:
48  TransferOptimization(RewriterBase &rewriter, Operation *op)
49  : rewriter(rewriter), dominators(op), postDominators(op) {}
50  void deadStoreOp(vector::TransferWriteOp);
51  void storeToLoadForwarding(vector::TransferReadOp);
52  void removeDeadOp() {
53  for (Operation *op : opToErase)
54  rewriter.eraseOp(op);
55  opToErase.clear();
56  }
57 
58 private:
59  RewriterBase &rewriter;
60  bool isReachable(Operation *start, Operation *dest);
61  DominanceInfo dominators;
62  PostDominanceInfo postDominators;
63  std::vector<Operation *> opToErase;
64 };
65 
66 } // namespace
67 /// Return true if there is a path from start operation to dest operation,
68 /// otherwise return false. The operations have to be in the same region.
69 bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
70  assert(start->getParentRegion() == dest->getParentRegion() &&
71  "This function only works for ops in the same region");
72  // Simple case where the start op dominates the destination.
73  if (dominators.dominates(start, dest))
74  return true;
75  Block *startBlock = start->getBlock();
76  Block *destBlock = dest->getBlock();
77  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
78  startBlock->succ_end());
79  llvm::SmallPtrSet<Block *, 32> visited;
80  while (!worklist.empty()) {
81  Block *bb = worklist.pop_back_val();
82  if (!visited.insert(bb).second)
83  continue;
84  if (dominators.dominates(bb, destBlock))
85  return true;
86  worklist.append(bb->succ_begin(), bb->succ_end());
87  }
88  return false;
89 }
90 
91 /// For a transfer_write to fully overwrite another transfer_write, it must:
92 /// 1. Access the same memref with the same indices and vector type.
93 /// 2. Post-dominate the other transfer_write operation.
94 /// If several candidates are available, one must be post-dominated by all the
95 /// others since they are all post-dominating the same transfer_write. We only
96 /// consider the transfer_write post-dominated by all the other candidates as
97 /// this will be the first transfer_write executed after the potentially dead
98 /// transfer_write.
99 /// If we find such an overwriting transfer_write, we know that the original
100 /// transfer_write is dead if all reads that can be reached from the potentially
101 /// dead transfer_write are dominated by the overwriting transfer_write.
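/// As an illustrative sketch (types, indices and attributes are assumed here,
/// not taken from a specific test), the first write below is dead because the
/// second one overwrites the same region, post-dominates it, and no read is
/// reachable in between:
///
///   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<1x4xf32>, memref<4x8xf32>   // dead, erased
///   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<1x4xf32>, memref<4x8xf32>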
102 void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
103  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
104  << "\n");
105  llvm::SmallVector<Operation *, 8> blockingAccesses;
106  Operation *firstOverwriteCandidate = nullptr;
107  Value source = write.getSource();
108  // Skip subview ops.
109  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
110  source = subView.getSource();
111  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
112  source.getUsers().end());
113  llvm::SmallDenseSet<Operation *, 32> processed;
114  while (!users.empty()) {
115  Operation *user = users.pop_back_val();
116  // If the user has already been processed skip.
117  if (!processed.insert(user).second)
118  continue;
119  if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
120  users.append(subView->getUsers().begin(), subView->getUsers().end());
121  continue;
122  }
123  if (isMemoryEffectFree(user))
124  continue;
125  if (user == write.getOperation())
126  continue;
127  if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
128  // Check whether this candidate can overwrite the store.
129  if (write.getSource() == nextWrite.getSource() &&
130  checkSameValueWAW(nextWrite, write) &&
131  postDominators.postDominates(nextWrite, write)) {
132  if (firstOverwriteCandidate == nullptr ||
133  postDominators.postDominates(firstOverwriteCandidate, nextWrite))
134  firstOverwriteCandidate = nextWrite;
135  else
136  assert(
137  postDominators.postDominates(nextWrite, firstOverwriteCandidate));
138  continue;
139  }
140  }
141  if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
142  // Don't need to consider disjoint accesses.
143  if (vector::isDisjointTransferSet(
144  cast<VectorTransferOpInterface>(write.getOperation()),
145  cast<VectorTransferOpInterface>(transferOp.getOperation()),
146  /*testDynamicValueUsingBounds=*/true))
147  continue;
148  }
149  blockingAccesses.push_back(user);
150  }
151  if (firstOverwriteCandidate == nullptr)
152  return;
153  Region *topRegion = firstOverwriteCandidate->getParentRegion();
154  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
155  assert(writeAncestor &&
156  "write op should be recursively part of the top region");
157 
158  for (Operation *access : blockingAccesses) {
159  Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
160  // TODO: if the access and write have the same ancestor we could recurse in
161  // the region to know if the access is reachable with more precision.
162  if (accessAncestor == nullptr ||
163  !isReachable(writeAncestor, accessAncestor))
164  continue;
165  if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
166  LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
167  << *accessAncestor << "\n");
168  return;
169  }
170  }
171  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
172  << " overwritten by: " << *firstOverwriteCandidate << "\n");
173  opToErase.push_back(write.getOperation());
174 }
175 
176 /// A transfer_write candidate for store-to-load forwarding must:
177 /// 1. Access the same memref with the same indices and vector type as the
178 /// transfer_read.
179 /// 2. Dominate the transfer_read operation.
180 /// If several candidates are available, one must be dominated by all the others
181 /// since they are all dominating the same transfer_read. We only consider the
182 /// transfer_write dominated by all the other candidates as this will be the
183 /// last transfer_write executed before the transfer_read.
184 /// If we found such a candidate we can do the forwarding if all the other
185 /// potentially aliasing ops that may reach the transfer_read are post-dominated
186 /// by the transfer_write.
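/// Illustrative sketch (assumed types and indices): the read below can be
/// replaced by the value %v written just before it, since the write dominates
/// the read, accesses the same indices with the same vector type, and no other
/// aliasing write can reach the read:
///
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<1x4xf32>, memref<4x8xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<1x4xf32>   // %r is replaced by %v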
187 void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
188  if (read.hasOutOfBoundsDim())
189  return;
190  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
191  << "\n");
192  SmallVector<Operation *, 8> blockingWrites;
193  vector::TransferWriteOp lastwrite = nullptr;
194  Value source = read.getSource();
195  // Skip subview ops.
196  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
197  source = subView.getSource();
198  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
199  source.getUsers().end());
200  llvm::SmallDenseSet<Operation *, 32> processed;
201  while (!users.empty()) {
202  Operation *user = users.pop_back_val();
203  // If the user has already been processed skip.
204  if (!processed.insert(user).second)
205  continue;
206  if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
207  users.append(subView->getUsers().begin(), subView->getUsers().end());
208  continue;
209  }
210  if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
211  users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
212  continue;
213  }
214  if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
215  continue;
216  if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
217  // If there is a write, but we can prove that it is disjoint we can ignore
218  // the write.
219  if (vector::isDisjointTransferSet(
220  cast<VectorTransferOpInterface>(write.getOperation()),
221  cast<VectorTransferOpInterface>(read.getOperation()),
222  /*testDynamicValueUsingBounds=*/true))
223  continue;
224  if (write.getSource() == read.getSource() &&
225  dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
226  if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
227  lastwrite = write;
228  else
229  assert(dominators.dominates(write, lastwrite));
230  continue;
231  }
232  }
233  blockingWrites.push_back(user);
234  }
235 
236  if (lastwrite == nullptr)
237  return;
238 
239  Region *topRegion = lastwrite->getParentRegion();
240  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
241  assert(readAncestor &&
242  "read op should be recursively part of the top region");
243 
244  for (Operation *write : blockingWrites) {
245  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
246  // TODO: if the store and read have the same ancestor we could recurse in
247  // the region to know if the read is reachable with more precision.
248  if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
249  continue;
250  if (!postDominators.postDominates(lastwrite, write)) {
251  LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
252  << *write << "\n");
253  return;
254  }
255  }
256 
257  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
258  << " to: " << *read.getOperation() << "\n");
259  read.replaceAllUsesWith(lastwrite.getVector());
260  opToErase.push_back(read.getOperation());
261 }
262 
263 /// Returns a copy of `shape` without unit dims.
264 static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
265  SmallVector<int64_t> reducedShape;
266  llvm::copy_if(shape, std::back_inserter(reducedShape),
267  [](int64_t dimSize) { return dimSize != 1; });
268  return reducedShape;
269 }
270 
271 /// Converts OpFoldResults to int64_t shape without unit dims.
272 static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
273  SmallVector<int64_t> reducedShape;
274  for (const auto size : mixedSizes) {
275  if (llvm::dyn_cast_if_present<Value>(size)) {
276  reducedShape.push_back(ShapedType::kDynamic);
277  continue;
278  }
279 
280  auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
281  if (value == 1)
282  continue;
283  reducedShape.push_back(value.getSExtValue());
284  }
285  return reducedShape;
286 }
287 
288 /// Drops unit dimensions from the input MemRefType.
289 static MemRefType dropUnitDims(MemRefType inputType,
290  ArrayRef<OpFoldResult> offsets,
291  ArrayRef<OpFoldResult> sizes,
292  ArrayRef<OpFoldResult> strides) {
293  auto targetShape = getReducedShape(sizes);
294  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
295  targetShape, inputType, offsets, sizes, strides);
296  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
297 }
298 
299 /// Creates a rank-reducing memref.subview op that drops unit dims from its
300 /// input, or returns the input unchanged if it has no unit dims.
301 static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
302  mlir::Location loc,
303  Value input) {
304  MemRefType inputType = cast<MemRefType>(input.getType());
305  SmallVector<OpFoldResult> offsets(inputType.getRank(),
306  rewriter.getIndexAttr(0));
307  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
308  SmallVector<OpFoldResult> strides(inputType.getRank(),
309  rewriter.getIndexAttr(1));
310  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
311 
312  if (canonicalizeStridedLayout(resultType) ==
313  canonicalizeStridedLayout(inputType))
314  return input;
315  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
316  sizes, strides);
317 }
318 
319 /// Returns the number of dims that aren't unit dims.
320 static int getReducedRank(ArrayRef<int64_t> shape) {
321  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
322 }
323 
324 /// Trims non-scalable unit (size-1) dimensions from `oldType` and returns the
325 /// resulting type.
326 static VectorType trimNonScalableUnitDims(VectorType oldType) {
327  SmallVector<int64_t> newShape;
328  SmallVector<bool> newScalableDims;
329  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
330  if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
331  continue;
332  newShape.push_back(dimSize);
333  newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
334  }
335  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
336 }
337 
338 // Rewrites vector.create_mask 'op' to drop non-scalable unit (size-1) dimensions.
339 static FailureOr<Value>
340 createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
341  vector::CreateMaskOp op) {
342  auto type = op.getType();
343  auto reducedType = trimNonScalableUnitDims(type);
344  if (reducedType.getRank() == type.getRank())
345  return failure();
346 
347  SmallVector<Value> reducedOperands;
348  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
349  type.getShape(), type.getScalableDims(), op.getOperands())) {
350  if (dim == 1 && !dimIsScalable) {
351  // If the mask for the unit dim is not a constant of 1, do nothing.
352  auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
353  if (!constant || (constant.value() != 1))
354  return failure();
355  continue;
356  }
357  reducedOperands.push_back(operand);
358  }
359  return rewriter
360  .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
361  .getResult();
362 }
363 
364 namespace {
365 
366 /// Rewrites `vector.transfer_read` ops where the source has unit dims, by
367 /// inserting a memref.subview dropping those unit dims. The vector shapes are
368 /// also reduced accordingly.
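/// Illustrative sketch (shapes assumed for exposition): a read of
/// vector<1x4x1xf32> from memref<1x4x1xf32> becomes, roughly,
///
///   %sv = memref.subview %m[0, 0, 0] [1, 4, 1] [1, 1, 1]
///             : memref<1x4x1xf32> to memref<4xf32>
///   %r  = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
///             : memref<4xf32>, vector<4xf32>
///   %v  = vector.shape_cast %r : vector<4xf32> to vector<1x4x1xf32>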
369 class TransferReadDropUnitDimsPattern
370  : public OpRewritePattern<vector::TransferReadOp> {
371  using OpRewritePattern::OpRewritePattern;
372 
373  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
374  PatternRewriter &rewriter) const override {
375  auto loc = transferReadOp.getLoc();
376  Value vector = transferReadOp.getVector();
377  VectorType vectorType = cast<VectorType>(vector.getType());
378  Value source = transferReadOp.getSource();
379  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
380  // TODO: support tensor types.
381  if (!sourceType)
382  return failure();
383  // TODO: generalize this pattern, relax the requirements here.
384  if (transferReadOp.hasOutOfBoundsDim())
385  return failure();
386  if (!transferReadOp.getPermutationMap().isMinorIdentity())
387  return failure();
388  // Check if the source shape can be further reduced.
389  int reducedRank = getReducedRank(sourceType.getShape());
390  if (reducedRank == sourceType.getRank())
391  return failure();
392  // Check if the reduced vector shape matches the reduced source shape.
393  // Otherwise, this case is not supported yet.
394  auto reducedVectorType = trimNonScalableUnitDims(vectorType);
395  if (reducedRank != reducedVectorType.getRank())
396  return failure();
397  if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
398  return getConstantIntValue(v) != static_cast<int64_t>(0);
399  }))
400  return failure();
401 
402  Value maskOp = transferReadOp.getMask();
403  if (maskOp) {
404  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
405  if (!createMaskOp)
406  return rewriter.notifyMatchFailure(
407  transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
408  "currently supported");
409  FailureOr<Value> rankReducedCreateMask =
410  createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
411  if (failed(rankReducedCreateMask))
412  return failure();
413  maskOp = *rankReducedCreateMask;
414  }
415 
416  Value reducedShapeSource =
417  rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
418  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
419  SmallVector<Value> zeros(reducedRank, c0);
420  auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
421  SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
422  auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
423  loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
424  transferReadOp.getPadding(), maskOp,
425  rewriter.getBoolArrayAttr(inBounds));
426  auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
427  loc, vectorType, newTransferReadOp);
428  rewriter.replaceOp(transferReadOp, shapeCast);
429 
430  return success();
431  }
432 };
433 
434 /// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
435 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
436 /// vector shapes are also reduced accordingly.
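/// Illustrative sketch (shapes assumed): writing vector<1x4x1xf32> into
/// memref<1x4x1xf32> becomes, roughly,
///
///   %sv = memref.subview %m[0, 0, 0] [1, 4, 1] [1, 1, 1]
///             : memref<1x4x1xf32> to memref<4xf32>
///   %sc = vector.shape_cast %v : vector<1x4x1xf32> to vector<4xf32>
///   vector.transfer_write %sc, %sv[%c0] : vector<4xf32>, memref<4xf32>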
437 class TransferWriteDropUnitDimsPattern
438  : public OpRewritePattern<vector::TransferWriteOp> {
439  using OpRewritePattern::OpRewritePattern;
440 
441  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
442  PatternRewriter &rewriter) const override {
443  auto loc = transferWriteOp.getLoc();
444  Value vector = transferWriteOp.getVector();
445  VectorType vectorType = cast<VectorType>(vector.getType());
446  Value source = transferWriteOp.getSource();
447  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
448  // TODO: support tensor type.
449  if (!sourceType || !sourceType.hasStaticShape())
450  return failure();
451  if (sourceType.getNumElements() != vectorType.getNumElements())
452  return failure();
453  // TODO: generalize this pattern, relax the requirements here.
454  if (transferWriteOp.hasOutOfBoundsDim())
455  return failure();
456  if (!transferWriteOp.getPermutationMap().isMinorIdentity())
457  return failure();
458  // Check if the destination shape can be further reduced.
459  int reducedRank = getReducedRank(sourceType.getShape());
460  if (reducedRank == sourceType.getRank())
461  return failure();
462  // Check if the reduced vector shape matches the reduced destination shape.
463  // Otherwise, this case is not supported yet.
464  int vectorReducedRank = getReducedRank(vectorType.getShape());
465  if (reducedRank != vectorReducedRank)
466  return failure();
467  if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
468  return getConstantIntValue(v) != static_cast<int64_t>(0);
469  }))
470  return failure();
471  Value reducedShapeSource =
472  rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
473  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
474  SmallVector<Value> zeros(reducedRank, c0);
475  auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
476  VectorType reducedVectorType = VectorType::get(
477  getReducedShape(vectorType.getShape()), vectorType.getElementType());
478 
479  auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
480  loc, reducedVectorType, vector);
481  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
482  transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
483 
484  return success();
485  }
486 };
487 
488 } // namespace
489 
490 /// Return true if the trailing dimensions of the memref type match the given
491 /// shape contiguously (row-major, innermost stride 1); otherwise return false.
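/// For example, a memref<2x3x4xf32> with the identity layout has strides
/// [12, 4, 1]; its two innermost dimensions match the target shape [3, 4]
/// (the flattened extents 4 and 12 equal the corresponding strides), but they
/// would not match [2, 4] or a non-contiguous (e.g. transposed) layout.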
492 static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
493  ArrayRef<int64_t> targetShape) {
494  auto shape = memrefType.getShape();
495  SmallVector<int64_t> strides;
496  int64_t offset;
497  if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
498  return false;
499  if (strides.back() != 1)
500  return false;
501  strides.pop_back();
502  int64_t flatDim = 1;
503  for (auto [targetDim, memrefDim, memrefStride] :
504  llvm::reverse(llvm::zip(targetShape, shape, strides))) {
505  flatDim *= memrefDim;
506  if (flatDim != memrefStride || targetDim != memrefDim)
507  return false;
508  }
509  return true;
510 }
511 
512 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
513 /// input starting at `firstDimToCollapse`.
514 static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
515  Value input, int64_t firstDimToCollapse) {
516  ShapedType inputType = cast<ShapedType>(input.getType());
517  if (inputType.getRank() == 1)
518  return input;
519  SmallVector<ReassociationIndices> reassociation;
520  for (int64_t i = 0; i < firstDimToCollapse; ++i)
521  reassociation.push_back(ReassociationIndices{i});
522  ReassociationIndices collapsedIndices;
523  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
524  collapsedIndices.push_back(i);
525  reassociation.push_back(collapsedIndices);
526  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
527 }
528 
529 /// Checks that the indices corresponding to dimensions starting at
530 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
531 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
532 static LogicalResult
533 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
534  SmallVector<Value> &outIndices) {
535  int64_t rank = indices.size();
536  if (firstDimToCollapse >= rank)
537  return failure();
538  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
539  std::optional<int64_t> cst = getConstantIntValue(indices[i]);
540  if (!cst || cst.value() != 0)
541  return failure();
542  }
543  outIndices = indices;
544  outIndices.resize(firstDimToCollapse + 1);
545  return success();
546 }
547 
548 namespace {
549 
550 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
551 /// memref.collapse_shape on the source so that the resulting
552 /// vector.transfer_read has a 1D source. Requires the source shape to be
553 /// already reduced i.e. without unit dims.
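/// Illustrative sketch (shapes assumed; the source must be contiguous
/// row-major): a 2-D read such as
///
///   %v = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<2x8xf32>
///
/// becomes a 1-D read of the collapsed memref followed by a shape_cast:
///
///   %cm = memref.collapse_shape %m [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %f  = vector.transfer_read %cm[%c0], %pad {in_bounds = [true]}
///       : memref<32xf32>, vector<16xf32>
///   %v  = vector.shape_cast %f : vector<16xf32> to vector<2x8xf32>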
554 class FlattenContiguousRowMajorTransferReadPattern
555  : public OpRewritePattern<vector::TransferReadOp> {
556  using OpRewritePattern::OpRewritePattern;
557 
558  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
559  PatternRewriter &rewriter) const override {
560  auto loc = transferReadOp.getLoc();
561  Value vector = transferReadOp.getVector();
562  VectorType vectorType = cast<VectorType>(vector.getType());
563  Value source = transferReadOp.getSource();
564  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
565  // The contiguity check below needs strides, so only memref sources are supported.
566  if (!sourceType)
567  return failure();
568  if (vectorType.getRank() <= 1)
569  // Already 0D/1D, nothing to do.
570  return failure();
571  if (!hasMatchingInnerContigousShape(
572  sourceType,
573  vectorType.getShape().take_back(vectorType.getRank() - 1)))
574  return failure();
575  int64_t firstContiguousInnerDim =
576  sourceType.getRank() - vectorType.getRank();
577  // TODO: generalize this pattern, relax the requirements here.
578  if (transferReadOp.hasOutOfBoundsDim())
579  return failure();
580  if (!transferReadOp.getPermutationMap().isMinorIdentity())
581  return failure();
582  if (transferReadOp.getMask())
583  return failure();
584  SmallVector<Value> collapsedIndices;
585  if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
586  firstContiguousInnerDim,
587  collapsedIndices)))
588  return failure();
589  Value collapsedSource =
590  collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
591  MemRefType collapsedSourceType =
592  dyn_cast<MemRefType>(collapsedSource.getType());
593  int64_t collapsedRank = collapsedSourceType.getRank();
594  assert(collapsedRank == firstContiguousInnerDim + 1);
595  SmallVector<AffineExpr, 1> dimExprs{
596  getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
597  auto collapsedMap =
598  AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
599  VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
600  vectorType.getElementType());
601  vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
602  loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
603  flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
604  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
605  transferReadOp, cast<VectorType>(vector.getType()), flatRead);
606  return success();
607  }
608 };
609 
610 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
611 /// memref.collapse_shape on the source so that the resulting
612 /// vector.transfer_write has a 1D source. Requires the source shape to be
613 /// already reduced i.e. without unit dims.
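/// Illustrative sketch (shapes assumed, mirroring the read pattern above):
/// writing vector<2x8xf32> into memref<4x8xf32> at [%c0, %c0] becomes
///
///   %cm = memref.collapse_shape %m [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %f  = vector.shape_cast %v : vector<2x8xf32> to vector<16xf32>
///   vector.transfer_write %f, %cm[%c0] {in_bounds = [true]}
///       : vector<16xf32>, memref<32xf32>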
614 class FlattenContiguousRowMajorTransferWritePattern
615  : public OpRewritePattern<vector::TransferWriteOp> {
616  using OpRewritePattern::OpRewritePattern;
617 
618  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
619  PatternRewriter &rewriter) const override {
620  auto loc = transferWriteOp.getLoc();
621  Value vector = transferWriteOp.getVector();
622  VectorType vectorType = cast<VectorType>(vector.getType());
623  Value source = transferWriteOp.getSource();
624  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
625  // The contiguity check below needs strides, so only memref sources are supported.
626  if (!sourceType)
627  return failure();
628  if (vectorType.getRank() <= 1)
629  // Already 0D/1D, nothing to do.
630  return failure();
631  if (!hasMatchingInnerContigousShape(
632  sourceType,
633  vectorType.getShape().take_back(vectorType.getRank() - 1)))
634  return failure();
635  int64_t firstContiguousInnerDim =
636  sourceType.getRank() - vectorType.getRank();
637  // TODO: generalize this pattern, relax the requirements here.
638  if (transferWriteOp.hasOutOfBoundsDim())
639  return failure();
640  if (!transferWriteOp.getPermutationMap().isMinorIdentity())
641  return failure();
642  if (transferWriteOp.getMask())
643  return failure();
644  SmallVector<Value> collapsedIndices;
645  if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
646  firstContiguousInnerDim,
647  collapsedIndices)))
648  return failure();
649  Value collapsedSource =
650  collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
651  MemRefType collapsedSourceType =
652  cast<MemRefType>(collapsedSource.getType());
653  int64_t collapsedRank = collapsedSourceType.getRank();
654  assert(collapsedRank == firstContiguousInnerDim + 1);
655  SmallVector<AffineExpr, 1> dimExprs{
656  getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
657  auto collapsedMap =
658  AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
659  VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
660  vectorType.getElementType());
661  Value flatVector =
662  rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
663  vector::TransferWriteOp flatWrite =
664  rewriter.create<vector::TransferWriteOp>(
665  loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
666  flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
667  rewriter.eraseOp(transferWriteOp);
668  return success();
669  }
670 };
671 
672 /// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
673 /// to `memref.load` patterns. The `match` method is shared for both
674 /// `vector.extract` and `vector.extractelement`.
675 template <class VectorExtractOp>
676 class RewriteScalarExtractOfTransferReadBase
677  : public OpRewritePattern<VectorExtractOp> {
678  using Base = OpRewritePattern<VectorExtractOp>;
679 
680 public:
681  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
682  PatternBenefit benefit,
683  bool allowMultipleUses)
684  : Base::OpRewritePattern(context, benefit),
685  allowMultipleUses(allowMultipleUses) {}
686 
687  LogicalResult match(VectorExtractOp extractOp) const override {
688  auto xferOp =
689  extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
690  if (!xferOp)
691  return failure();
692  // Check that we are extracting a scalar and not a sub-vector.
693  if (isa<VectorType>(extractOp.getResult().getType()))
694  return failure();
695  // If multiple uses are not allowed, check if xfer has a single use.
696  if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
697  return failure();
698  // If multiple uses are allowed, check if all the xfer uses are extract ops.
699  if (allowMultipleUses &&
700  !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
701  return isa<vector::ExtractOp, vector::ExtractElementOp>(
702  use.getOwner());
703  }))
704  return failure();
705  // Mask not supported.
706  if (xferOp.getMask())
707  return failure();
708  // Map not supported.
709  if (!xferOp.getPermutationMap().isMinorIdentity())
710  return failure();
711  // Cannot rewrite if the indices may be out of bounds.
712  if (xferOp.hasOutOfBoundsDim())
713  return failure();
714  return success();
715  }
716 
717 private:
718  bool allowMultipleUses;
719 };
720 
721 /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
722 ///
723 /// All the users of the transfer op must be either `vector.extractelement` or
724 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
725 /// transfer ops with any number of users. Otherwise, rewrite only if the
726 /// extract op is the single user of the transfer op. Rewriting a single
727 /// vector load with multiple scalar loads may negatively affect performance.
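/// Illustrative sketch (types assumed; the index arithmetic is shown after
/// folding):
///
///   %v = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
///       : memref<16xf32>, vector<4xf32>
///   %e = vector.extractelement %v[%c2 : index] : vector<4xf32>
///
/// becomes, roughly,
///
///   %idx = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%i]
///   %e   = memref.load %m[%idx] : memref<16xf32>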
728 class RewriteScalarExtractElementOfTransferRead
729  : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
730  using RewriteScalarExtractOfTransferReadBase::
731  RewriteScalarExtractOfTransferReadBase;
732 
733  void rewrite(vector::ExtractElementOp extractOp,
734  PatternRewriter &rewriter) const override {
735  // Construct scalar load.
736  auto loc = extractOp.getLoc();
737  auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
738  SmallVector<Value> newIndices(xferOp.getIndices().begin(),
739  xferOp.getIndices().end());
740  if (extractOp.getPosition()) {
741  AffineExpr sym0, sym1;
742  bindSymbols(extractOp.getContext(), sym0, sym1);
743  OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
744  rewriter, loc, sym0 + sym1,
745  {newIndices[newIndices.size() - 1], extractOp.getPosition()});
746  if (ofr.is<Value>()) {
747  newIndices[newIndices.size() - 1] = ofr.get<Value>();
748  } else {
749  newIndices[newIndices.size() - 1] =
750  rewriter.create<arith::ConstantIndexOp>(loc,
751  *getConstantIntValue(ofr));
752  }
753  }
754  if (isa<MemRefType>(xferOp.getSource().getType())) {
755  rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
756  newIndices);
757  } else {
758  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
759  extractOp, xferOp.getSource(), newIndices);
760  }
761  }
762 };
763 
764 /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
765 /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
766 ///
767 /// All the users of the transfer op must be either `vector.extractelement` or
768 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
769 /// transfer ops with any number of users. Otherwise, rewrite only if the
770 /// extract op is the single user of the transfer op. Rewriting a single
771 /// vector load with multiple scalar loads may negatively affect performance.
772 class RewriteScalarExtractOfTransferRead
773  : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
774  using RewriteScalarExtractOfTransferReadBase::
775  RewriteScalarExtractOfTransferReadBase;
776 
777  void rewrite(vector::ExtractOp extractOp,
778  PatternRewriter &rewriter) const override {
779  // Construct scalar load.
780  auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
781  SmallVector<Value> newIndices(xferOp.getIndices().begin(),
782  xferOp.getIndices().end());
783  for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
784  assert(pos.is<Attribute>() && "Unexpected non-constant index");
785  int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
786  int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
787  OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
788  rewriter, extractOp.getLoc(),
789  rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
790  if (ofr.is<Value>()) {
791  newIndices[idx] = ofr.get<Value>();
792  } else {
793  newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
794  extractOp.getLoc(), *getConstantIntValue(ofr));
795  }
796  }
797  if (isa<MemRefType>(xferOp.getSource().getType())) {
798  rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
799  newIndices);
800  } else {
801  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
802  extractOp, xferOp.getSource(), newIndices);
803  }
804  }
805 };
806 
807 /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
808 /// to memref.store.
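/// Illustrative sketch (types assumed; vector.extract syntax abbreviated):
///
///   vector.transfer_write %v, %m[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<4x4xf32>
///
/// becomes, roughly,
///
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %m[%i, %j] : memref<4x4xf32>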
809 class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
810  using OpRewritePattern::OpRewritePattern;
811 
812  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
813  PatternRewriter &rewriter) const override {
814  // Must be a scalar write.
815  auto vecType = xferOp.getVectorType();
816  if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
817  return failure();
818  // Mask not supported.
819  if (xferOp.getMask())
820  return failure();
821  // Map not supported.
822  if (!xferOp.getPermutationMap().isMinorIdentity())
823  return failure();
824  // Only float and integer element types are supported.
825  Value scalar;
826  if (vecType.getRank() == 0) {
827  // vector.extract does not support vector<f32> etc., so use
828  // vector.extractelement instead.
829  scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
830  xferOp.getVector());
831  } else {
832  SmallVector<int64_t> pos(vecType.getRank(), 0);
833  scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
834  xferOp.getVector(), pos);
835  }
836  // Construct a scalar store.
837  if (isa<MemRefType>(xferOp.getSource().getType())) {
838  rewriter.replaceOpWithNewOp<memref::StoreOp>(
839  xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
840  } else {
841  rewriter.replaceOpWithNewOp<tensor::InsertOp>(
842  xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
843  }
844  return success();
845  }
846 };
847 
848 } // namespace
849 
850 void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
851  Operation *rootOp) {
852  TransferOptimization opt(rewriter, rootOp);
853  // Run store to load forwarding first since it can expose more dead store
854  // opportunities.
855  rootOp->walk([&](vector::TransferReadOp read) {
856  if (isa<MemRefType>(read.getShapedType()))
857  opt.storeToLoadForwarding(read);
858  });
859  opt.removeDeadOp();
860  rootOp->walk([&](vector::TransferWriteOp write) {
861  if (isa<MemRefType>(write.getShapedType()))
862  opt.deadStoreOp(write);
863  });
864  opt.removeDeadOp();
865 }
866 
867 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
868  RewritePatternSet &patterns, PatternBenefit benefit,
869  bool allowMultipleUses) {
870  patterns.add<RewriteScalarExtractElementOfTransferRead,
871  RewriteScalarExtractOfTransferRead>(patterns.getContext(),
872  benefit, allowMultipleUses);
873  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
874 }
875 
876 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
877  RewritePatternSet &patterns, PatternBenefit benefit) {
878  patterns
879  .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
880  patterns.getContext(), benefit);
881  populateShapeCastFoldingPatterns(patterns, benefit);
882 }
883 
884 void mlir::vector::populateFlattenVectorTransferPatterns(
885  RewritePatternSet &patterns, PatternBenefit benefit) {
886  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
887  FlattenContiguousRowMajorTransferWritePattern>(
888  patterns.getContext(), benefit);
889  populateShapeCastFoldingPatterns(patterns, benefit);
890 }
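// Typical usage from a pass (a minimal sketch; the pass boilerplate and the
// greedy driver invocation are assumptions, not part of this file):
//
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));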