MLIR  20.0.0git
VectorTransferOpTransforms.cpp
Go to the documentation of this file.
1 //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with optimizing transfer_read and
10 // transfer_write ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
24 #include "mlir/IR/Dominance.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/Debug.h"
29 
30 #define DEBUG_TYPE "vector-transfer-opt"
31 
32 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
33 
34 using namespace mlir;
35 
36 /// Return the ancestor op in the region or nullptr if the region is not
37 /// an ancestor of the op.
/// Walks up the parent-op chain starting at `op` itself, stopping at the first
/// op whose parent region is `region`; if `op` already lives directly in
/// `region`, `op` is returned unchanged.
39  for (; op != nullptr && op->getParentRegion() != region;
40  op = op->getParentOp())
  // Empty body: all the work happens in the loop header.
41  ;
42  return op;
43 }
44 
45 namespace {
46 
47 class TransferOptimization {
48 public:
49  TransferOptimization(RewriterBase &rewriter, Operation *op)
50  : rewriter(rewriter), dominators(op), postDominators(op) {}
51  void deadStoreOp(vector::TransferWriteOp);
52  void storeToLoadForwarding(vector::TransferReadOp);
53  void removeDeadOp() {
54  for (Operation *op : opToErase)
55  rewriter.eraseOp(op);
56  opToErase.clear();
57  }
58 
59 private:
60  RewriterBase &rewriter;
61  bool isReachable(Operation *start, Operation *dest);
62  DominanceInfo dominators;
63  PostDominanceInfo postDominators;
64  std::vector<Operation *> opToErase;
65 };
66 
67 } // namespace
68 /// Return true if there is a path from start operation to dest operation,
69 /// otherwise return false. The operations have to be in the same region.
70 bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
71  assert(start->getParentRegion() == dest->getParentRegion() &&
72  "This function only works for ops i the same region");
73  // Simple case where the start op dominate the destination.
74  if (dominators.dominates(start, dest))
75  return true;
76  return start->getBlock()->isReachable(dest->getBlock());
77 }
78 
79 /// For transfer_write to overwrite fully another transfer_write must:
80 /// 1. Access the same memref with the same indices and vector type.
81 /// 2. Post-dominate the other transfer_write operation.
82 /// If several candidates are available, one must be post-dominated by all the
83 /// others since they are all post-dominating the same transfer_write. We only
84 /// consider the transfer_write post-dominated by all the other candidates as
85 /// this will be the first transfer_write executed after the potentially dead
86 /// transfer_write.
87 /// If we found such an overwriting transfer_write we know that the original
88 /// transfer_write is dead if all reads that can be reached from the potentially
89 /// dead transfer_write are dominated by the overwriting transfer_write.
///
/// Note: a dead `write` is not erased here; it is only appended to
/// `opToErase`, which `removeDeadOp` processes later.
90 void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
91  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
92  << "\n");
  // Walk every user of the underlying buffer (looking through view-like ops)
  // and classify each memory-effecting user either as an overwrite candidate
  // or as a potentially blocking access.
93  llvm::SmallVector<Operation *, 8> blockingAccesses;
94  Operation *firstOverwriteCandidate = nullptr;
95  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource()));
96  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
97  source.getUsers().end());
98  llvm::SmallDenseSet<Operation *, 32> processed;
99  while (!users.empty()) {
100  Operation *user = users.pop_back_val();
101  // If the user has already been processed skip.
102  if (!processed.insert(user).second)
103  continue;
  // Look through view-like ops: their users are visited as well.
104  if (isa<ViewLikeOpInterface>(user)) {
105  users.append(user->getUsers().begin(), user->getUsers().end());
106  continue;
107  }
  // Ops without memory effects cannot interfere with the store.
108  if (isMemoryEffectFree(user))
109  continue;
110  if (user == write.getOperation())
111  continue;
112  if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
113  // Check candidate that can override the store.
115  cast<MemrefValue>(nextWrite.getSource()),
116  cast<MemrefValue>(write.getSource())) &&
117  checkSameValueWAW(nextWrite, write) &&
118  postDominators.postDominates(nextWrite, write)) {
  // All candidates post-dominate `write`, so they are expected to be ordered
  // by post-dominance; keep the one post-dominated by all the others.
119  if (firstOverwriteCandidate == nullptr ||
120  postDominators.postDominates(firstOverwriteCandidate, nextWrite))
121  firstOverwriteCandidate = nextWrite;
122  else
123  assert(
124  postDominators.postDominates(nextWrite, firstOverwriteCandidate));
125  continue;
126  }
127  }
128  if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
129  // Don't need to consider disjoint accesses.
131  cast<VectorTransferOpInterface>(write.getOperation()),
132  cast<VectorTransferOpInterface>(transferOp.getOperation()),
133  /*testDynamicValueUsingBounds=*/true))
134  continue;
135  }
  // Any remaining access (including transfer_writes that are not a full
  // overwrite) is conservatively treated as blocking.
136  blockingAccesses.push_back(user);
137  }
  // Without a full overwrite the store cannot be proven dead.
138  if (firstOverwriteCandidate == nullptr)
139  return;
  // Reachability and dominance are checked between ancestors that live in
  // the overwrite's region.
140  Region *topRegion = firstOverwriteCandidate->getParentRegion();
141  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
142  assert(writeAncestor &&
143  "write op should be recursively part of the top region");
144 
  // `write` is dead only if every blocking access reachable from it is
  // dominated by the overwriting transfer_write.
145  for (Operation *access : blockingAccesses) {
146  Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
147  // TODO: if the access and write have the same ancestor we could recurse in
148  // the region to know if the access is reachable with more precision.
149  if (accessAncestor == nullptr ||
150  !isReachable(writeAncestor, accessAncestor))
151  continue;
152  if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
153  LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
154  << *accessAncestor << "\n");
155  return;
156  }
157  }
158  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
159  << " overwritten by: " << *firstOverwriteCandidate << "\n");
160  opToErase.push_back(write.getOperation());
161 }
162 
163 /// A transfer_write candidate to storeToLoad forwarding must:
164 /// 1. Access the same memref with the same indices and vector type as the
165 /// transfer_read.
166 /// 2. Dominate the transfer_read operation.
167 /// If several candidates are available, one must be dominated by all the others
168 /// since they are all dominating the same transfer_read. We only consider the
169 /// transfer_write dominated by all the other candidates as this will be the
170 /// last transfer_write executed before the transfer_read.
171 /// If we found such a candidate we can do the forwarding if all the other
172 /// potentially aliasing ops that may reach the transfer_read are post-dominated
173 /// by the transfer_write.
///
/// Note: uses of `read` are replaced immediately, but `read` itself is only
/// appended to `opToErase` for later erasure by `removeDeadOp`.
174 void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  // Reads with a possibly out-of-bounds dimension are never forwarded.
175  if (read.hasOutOfBoundsDim())
176  return;
177  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
178  << "\n");
  // Walk every user of the underlying buffer (looking through view-like ops)
  // and classify memory-effecting users either as the forwarding candidate or
  // as potentially blocking writes.
179  SmallVector<Operation *, 8> blockingWrites;
180  vector::TransferWriteOp lastwrite = nullptr;
181  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource()));
182  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
183  source.getUsers().end());
184  llvm::SmallDenseSet<Operation *, 32> processed;
185  while (!users.empty()) {
186  Operation *user = users.pop_back_val();
187  // If the user has already been processed skip.
188  if (!processed.insert(user).second)
189  continue;
  // Look through view-like ops: their users are visited as well.
190  if (isa<ViewLikeOpInterface>(user)) {
191  users.append(user->getUsers().begin(), user->getUsers().end());
192  continue;
193  }
  // Effect-free ops and other transfer_reads never block forwarding.
194  if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
195  continue;
196  if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
197  // If there is a write, but we can prove that it is disjoint we can ignore
198  // the write.
200  cast<VectorTransferOpInterface>(write.getOperation()),
201  cast<VectorTransferOpInterface>(read.getOperation()),
202  /*testDynamicValueUsingBounds=*/true))
203  continue;
205  cast<MemrefValue>(read.getSource()),
206  cast<MemrefValue>(write.getSource())) &&
207  dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
  // All candidates dominate `read`, so they are expected to be ordered by
  // dominance; keep the one dominated by all the others.
208  if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
209  lastwrite = write;
210  else
211  assert(dominators.dominates(write, lastwrite));
212  continue;
213  }
214  }
  // Any other memory-effecting user is conservatively treated as blocking.
215  blockingWrites.push_back(user);
216  }
217 
  // No dominating write with a matching access: nothing to forward.
218  if (lastwrite == nullptr)
219  return;
220 
221  Region *topRegion = lastwrite->getParentRegion();
222  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
223  assert(readAncestor &&
224  "read op should be recursively part of the top region");
225 
  // Forwarding is legal only if every blocking op that can reach `read` is
  // post-dominated by `lastwrite`.
226  for (Operation *write : blockingWrites) {
227  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
228  // TODO: if the store and read have the same ancestor we could recurse in
229  // the region to know if the read is reachable with more precision.
230  if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
231  continue;
232  if (!postDominators.postDominates(lastwrite, write)) {
233  LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
234  << *write << "\n");
235  return;
236  }
237  }
238 
239  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
240  << " to: " << *read.getOperation() << "\n");
241  read.replaceAllUsesWith(lastwrite.getVector());
242  opToErase.push_back(read.getOperation());
243 }
244 
245 /// Converts OpFoldResults to int64_t shape without unit dims.
/// Sizes held as a Value are kept and recorded as ShapedType::kDynamic; static
/// sizes equal to 1 are dropped; all other static sizes are kept as-is.
247  SmallVector<int64_t> reducedShape;
248  for (const auto size : mixedSizes) {
249  if (llvm::dyn_cast_if_present<Value>(size)) {
250  reducedShape.push_back(ShapedType::kDynamic);
251  continue;
252  }
253 
  // Non-Value sizes are expected to be IntegerAttrs (`cast` asserts this).
254  auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
255  if (value == 1)
256  continue;
257  reducedShape.push_back(value.getSExtValue());
258  }
259  return reducedShape;
260 }
261 
262 /// Drops unit dimensions from the input MemRefType.
/// Returns the rank-reduced type inferred for a memref.subview of `inputType`
/// with the given offsets/sizes/strides, where the target shape omits static
/// unit sizes; the resulting strided layout is canonicalized.
263 static MemRefType dropUnitDims(MemRefType inputType,
264  ArrayRef<OpFoldResult> offsets,
266  ArrayRef<OpFoldResult> strides) {
267  auto targetShape = getReducedShape(sizes);
268  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
269  targetShape, inputType, offsets, sizes, strides);
270  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
271 }
272 
273 /// Creates a rank-reducing memref.subview op that drops unit dims from its
274 /// input. Or just returns the input if it was already without unit dims.
/// The subview spans the whole input: all-zero offsets, unit strides and the
/// input's own (possibly dynamic) sizes. No op is created when the reduced type
/// equals the input type after layout canonicalization.
276  mlir::Location loc,
277  Value input) {
278  MemRefType inputType = cast<MemRefType>(input.getType());
279  SmallVector<OpFoldResult> offsets(inputType.getRank(),
280  rewriter.getIndexAttr(0));
281  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
282  SmallVector<OpFoldResult> strides(inputType.getRank(),
283  rewriter.getIndexAttr(1));
284  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
285 
  // Nothing was dropped: reuse the input instead of creating an identity view.
286  if (canonicalizeStridedLayout(resultType) ==
287  canonicalizeStridedLayout(inputType))
288  return input;
289  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
290  sizes, strides);
291 }
292 
293 /// Returns the number of dims that aren't unit dims.
/// Only entries exactly equal to 1 are treated as unit dims; every other entry
/// (presumably including ShapedType::kDynamic) is counted.
295  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
296 }
297 
298 /// Trims non-scalable one dimensions from `oldType` and returns the result
299 /// type.
300 static VectorType trimNonScalableUnitDims(VectorType oldType) {
301  SmallVector<int64_t> newShape;
302  SmallVector<bool> newScalableDims;
303  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
304  if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
305  continue;
306  newShape.push_back(dimSize);
307  newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
308  }
309  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
310 }
311 
312 // Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
/// Returns failure, without creating any op, when `op` has no non-scalable
/// unit dim, or when the bound operand of such a dim is not produced by an
/// arith.constant index op with value 1. Otherwise creates a new
/// vector.create_mask of the trimmed type over the remaining bound operands.
313 static FailureOr<Value>
315  vector::CreateMaskOp op) {
316  auto type = op.getType();
317  VectorType reducedType = trimNonScalableUnitDims(type);
318  if (reducedType.getRank() == type.getRank())
319  return failure();
320 
  // Keep the bounds of all dims that survive the trimming; dropped dims must
  // have a constant bound of 1.
321  SmallVector<Value> reducedOperands;
322  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
323  type.getShape(), type.getScalableDims(), op.getOperands())) {
324  if (dim == 1 && !dimIsScalable) {
325  // If the mask for the unit dim is not a constant of 1, do nothing.
326  auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
327  if (!constant || (constant.value() != 1))
328  return failure();
329  continue;
330  }
331  reducedOperands.push_back(operand);
332  }
333  return rewriter
334  .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
335  .getResult();
336 }
337 
338 namespace {
339 
340 /// Rewrites `vector.transfer_read` ops where the source has unit dims, by
341 /// inserting a memref.subview dropping those unit dims. The vector shapes are
342 /// also reduced accordingly.
343 class TransferReadDropUnitDimsPattern
344  : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
345  using MaskableOpRewritePattern::MaskableOpRewritePattern;
346 
347  FailureOr<Value>
348  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
349  vector::MaskingOpInterface maskingOp,
350  PatternRewriter &rewriter) const override {
351  auto loc = transferReadOp.getLoc();
352  Value vector = transferReadOp.getVector();
353  VectorType vectorType = cast<VectorType>(vector.getType());
354  Value source = transferReadOp.getSource();
355  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
356  // TODO: support tensor types.
357  if (!sourceType)
358  return failure();
359  // TODO: generalize this pattern, relax the requirements here.
360  if (transferReadOp.hasOutOfBoundsDim())
361  return failure();
362  if (!transferReadOp.getPermutationMap().isMinorIdentity())
363  return failure();
364  // Check if the source shape can be further reduced.
365  int reducedRank = getReducedRank(sourceType.getShape());
366  if (reducedRank == sourceType.getRank())
367  return failure();
368  // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
369  // out.
370  if (reducedRank == 0 && maskingOp)
371  return failure();
372  // Check if the reduced vector shape matches the reduced source shape.
373  // Otherwise, this case is not supported yet.
374  VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
375  if (reducedRank != reducedVectorType.getRank())
376  return failure();
377  if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
378  return getConstantIntValue(v) != static_cast<int64_t>(0);
379  }))
380  return failure();
381 
382  Value maskOp = transferReadOp.getMask();
383  if (maskOp) {
384  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
385  if (!createMaskOp)
386  return rewriter.notifyMatchFailure(
387  transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
388  "currently supported");
389  FailureOr<Value> rankReducedCreateMask =
390  createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
391  if (failed(rankReducedCreateMask))
392  return failure();
393  maskOp = *rankReducedCreateMask;
394  }
395 
396  Value reducedShapeSource =
397  rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
398  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
399  SmallVector<Value> zeros(reducedRank, c0);
400  auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
401  SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
402  Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
403  loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
404  transferReadOp.getPadding(), maskOp,
405  rewriter.getBoolArrayAttr(inBounds));
406 
407  if (maskingOp) {
408  auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
409  loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
410  maskingOp.getMask());
411  newTransferReadOp = mlir::vector::maskOperation(
412  rewriter, newTransferReadOp, shapeCastMask);
413  }
414 
415  auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
416  loc, vectorType, newTransferReadOp->getResults()[0]);
417 
418  return shapeCast;
419  }
420 };
421 
422 /// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
423 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
424 /// vector shapes are also reduced accordingly.
425 class TransferWriteDropUnitDimsPattern
426  : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
427  using MaskableOpRewritePattern::MaskableOpRewritePattern;
428 
429  FailureOr<Value>
430  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
431  vector::MaskingOpInterface maskingOp,
432  PatternRewriter &rewriter) const override {
433  auto loc = transferWriteOp.getLoc();
434  Value vector = transferWriteOp.getVector();
435  VectorType vectorType = cast<VectorType>(vector.getType());
436  Value source = transferWriteOp.getSource();
437  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
438  // TODO: support tensor type.
439  if (!sourceType)
440  return failure();
441  // TODO: generalize this pattern, relax the requirements here.
442  if (transferWriteOp.hasOutOfBoundsDim())
443  return failure();
444  if (!transferWriteOp.getPermutationMap().isMinorIdentity())
445  return failure();
446  // Check if the destination shape can be further reduced.
447  int reducedRank = getReducedRank(sourceType.getShape());
448  if (reducedRank == sourceType.getRank())
449  return failure();
450  // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
451  // out.
452  if (reducedRank == 0 && maskingOp)
453  return failure();
454  // Check if the reduced vector shape matches the reduced destination shape.
455  // Otherwise, this case is not supported yet.
456  VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
457  if (reducedRank != reducedVectorType.getRank())
458  return failure();
459  if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
460  return getConstantIntValue(v) != static_cast<int64_t>(0);
461  }))
462  return failure();
463 
464  Value maskOp = transferWriteOp.getMask();
465  if (maskOp) {
466  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
467  if (!createMaskOp)
468  return rewriter.notifyMatchFailure(
469  transferWriteOp,
470  "unsupported mask op, only 'vector.create_mask' is "
471  "currently supported");
472  FailureOr<Value> rankReducedCreateMask =
473  createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
474  if (failed(rankReducedCreateMask))
475  return failure();
476  maskOp = *rankReducedCreateMask;
477  }
478  Value reducedShapeSource =
479  rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
480  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
481  SmallVector<Value> zeros(reducedRank, c0);
482  auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
483  SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
484  auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
485  loc, reducedVectorType, vector);
486  Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
487  loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
488  maskOp, rewriter.getBoolArrayAttr(inBounds));
489 
490  if (maskingOp) {
491  auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
492  loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
493  maskingOp.getMask());
494  newXferWrite =
495  mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
496  }
497 
498  if (transferWriteOp.hasPureTensorSemantics())
499  return newXferWrite->getResults()[0];
500 
501  // With Memref semantics, there's no return value. Use empty value to signal
502  // success.
503  return Value();
504  }
505 };
506 
507 } // namespace
508 
509 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
510 /// input starting at `firstDimToCollapse`.
/// Each dim before `firstDimToCollapse` keeps its own reassociation group and
/// all remaining dims form one trailing group, so the result has
/// `firstDimToCollapse + 1` dims. Rank-1 inputs are returned unchanged and no
/// op is created.
512  Value input, int64_t firstDimToCollapse) {
513  ShapedType inputType = cast<ShapedType>(input.getType());
514  if (inputType.getRank() == 1)
515  return input;
516  SmallVector<ReassociationIndices> reassociation;
517  for (int64_t i = 0; i < firstDimToCollapse; ++i)
518  reassociation.push_back(ReassociationIndices{i});
519  ReassociationIndices collapsedIndices;
520  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
521  collapsedIndices.push_back(i);
522  reassociation.push_back(collapsedIndices);
523  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
524 }
525 
526 /// Returns the new indices that collapses the inner dimensions starting from
527 /// the `firstDimToCollapse` dimension.
/// Indices before `firstDimToCollapse` are kept as-is. The indices from
/// `firstDimToCollapse` onwards are linearized into one trailing index using
/// row-major (suffix-product) strides derived from `shape`.
529  Location loc,
530  ArrayRef<int64_t> shape,
531  ValueRange indices,
532  int64_t firstDimToCollapse) {
533  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
534 
535  // If all the collapsed indices are zero then no extra logic is needed.
536  // Otherwise, a new offset/index has to be computed.
537  SmallVector<Value> indicesAfterCollapsing(
538  indices.begin(), indices.begin() + firstDimToCollapse);
539  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
540  indices.end());
541  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
  // Reuse one of the existing zero index values as the collapsed index.
542  indicesAfterCollapsing.push_back(indicesToCollapse[0]);
543  return indicesAfterCollapsing;
544  }
545 
546  // Compute the remaining trailing index/offset required for reading from
547  // the collapsed memref:
548  //
549  // offset = 0
550  // for (i = firstDimToCollapse; i < rank; ++i)
551  //   offset += indices[i] * (shape[i + 1] * ... * shape[rank - 1])
552  //
553  // For this example:
554  // %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
555  // memref<1x43x2xi32>, vector<1x2xi32>
556  // which would be collapsed to:
557  // %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
558  // memref<1x86xi32>, vector<2xi32>
559  // one would get the following offset (the collapsed strides are [2, 1]):
560  // %offset = %arg0 * 2
561  OpFoldResult collapsedOffset =
562  rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
563 
564  auto collapsedStrides = computeSuffixProduct(
565  ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
566 
567  // Compute the collapsed offset.
568  auto &&[collapsedExpr, collapsedVals] =
569  computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
570  collapsedOffset = affine::makeComposedFoldedAffineApply(
571  rewriter, loc, collapsedExpr, collapsedVals);
572 
  // The folded offset may be an attribute; materialize it as an index
  // constant so the result is always a list of Values.
573  if (collapsedOffset.is<Value>()) {
574  indicesAfterCollapsing.push_back(collapsedOffset.get<Value>());
575  } else {
576  indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
577  loc, *getConstantIntValue(collapsedOffset)));
578  }
579 
580  return indicesAfterCollapsing;
581 }
582 
583 namespace {
584 
585 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
586 /// memref.collapse_shape on the source so that the resulting
587 /// vector.transfer_read has a 1D source. Requires the source shape to be
588 /// already reduced i.e. without unit dims.
589 ///
590 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
591 /// the trailing dimension of the vector read (size times element bitwidth) is
592 /// smaller than the provided bitwidth.
593 class FlattenContiguousRowMajorTransferReadPattern
594  : public OpRewritePattern<vector::TransferReadOp> {
595 public:
596  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
597  unsigned vectorBitwidth,
598  PatternBenefit benefit)
599  : OpRewritePattern<vector::TransferReadOp>(context, benefit),
600  targetVectorBitwidth(vectorBitwidth) {}
601 
602  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
603  PatternRewriter &rewriter) const override {
604  auto loc = transferReadOp.getLoc();
605  Value vector = transferReadOp.getVector();
606  VectorType vectorType = cast<VectorType>(vector.getType());
607  auto source = transferReadOp.getSource();
608  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
609 
610  // 0. Check pre-conditions
611  // Contiguity check is valid on memrefs only.
612  if (!sourceType)
613  return failure();
614  // If this is already 0D/1D, there's nothing to do.
615  if (vectorType.getRank() <= 1)
616  return failure();
617  if (!vectorType.getElementType().isSignlessIntOrFloat())
618  return failure();
  // NOTE(review): this product is narrowed to `unsigned`; confirm it cannot
  // overflow for very large trailing dimensions.
619  unsigned trailingVectorDimBitwidth =
620  vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
621  if (trailingVectorDimBitwidth >= targetVectorBitwidth)
622  return failure();
623  if (!vector::isContiguousSlice(sourceType, vectorType))
624  return failure();
625  // TODO: generalize this pattern, relax the requirements here.
626  if (transferReadOp.hasOutOfBoundsDim())
627  return failure();
628  if (!transferReadOp.getPermutationMap().isMinorIdentity())
629  return failure();
630  if (transferReadOp.getMask())
631  return failure();
632 
  // The innermost `vectorType.getRank()` source dims are the ones collapsed.
633  int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
634 
635  // 1. Collapse the source memref
636  Value collapsedSource =
637  collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
638  MemRefType collapsedSourceType =
639  cast<MemRefType>(collapsedSource.getType());
640  int64_t collapsedRank = collapsedSourceType.getRank();
641  assert(collapsedRank == firstDimToCollapse + 1);
642 
643  // 2. Generate input args for a new vector.transfer_read that will read
644  // from the collapsed memref.
645  // 2.1. New dim exprs + affine map
647  getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
648  auto collapsedMap =
649  AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
650 
651  // 2.2 New indices
652  SmallVector<Value> collapsedIndices =
653  getCollapsedIndices(rewriter, loc, sourceType.getShape(),
654  transferReadOp.getIndices(), firstDimToCollapse);
655 
656  // 3. Create new vector.transfer_read that reads from the collapsed memref
657  VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
658  vectorType.getElementType());
659  vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
660  loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
  // The original read had no out-of-bounds dim, so the flat read is in bounds.
661  flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
662 
663  // 4. Replace the old transfer_read with the new one reading from the
664  // collapsed shape
665  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
666  transferReadOp, cast<VectorType>(vector.getType()), flatRead);
667  return success();
668  }
669 
670 private:
671  // Flattening is only attempted when the trailing vector dimension (size times
672  // element bitwidth) is strictly narrower than this many bits.
673  unsigned targetVectorBitwidth;
674 };
675 
676 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
677 /// memref.collapse_shape on the source so that the resulting
678 /// vector.transfer_write has a 1D source. Requires the source shape to be
679 /// already reduced i.e. without unit dims.
680 ///
681 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
682 /// the trailing dimension of the vector written (size times element bitwidth)
683 /// is smaller than the provided bitwidth.
684 class FlattenContiguousRowMajorTransferWritePattern
685  : public OpRewritePattern<vector::TransferWriteOp> {
686 public:
687  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
688  unsigned vectorBitwidth,
689  PatternBenefit benefit)
690  : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
691  targetVectorBitwidth(vectorBitwidth) {}
692 
693  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
694  PatternRewriter &rewriter) const override {
695  auto loc = transferWriteOp.getLoc();
696  Value vector = transferWriteOp.getVector();
697  VectorType vectorType = cast<VectorType>(vector.getType());
698  Value source = transferWriteOp.getSource();
699  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
700 
701  // 0. Check pre-conditions
702  // Contiguity check is valid on memrefs only.
703  if (!sourceType)
704  return failure();
705  // If this is already 0D/1D, there's nothing to do.
706  if (vectorType.getRank() <= 1)
707  // Already 0D/1D, nothing to do.
708  return failure();
709  if (!vectorType.getElementType().isSignlessIntOrFloat())
710  return failure();
  // NOTE(review): this product is narrowed to `unsigned`; confirm it cannot
  // overflow for very large trailing dimensions.
711  unsigned trailingVectorDimBitwidth =
712  vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
713  if (trailingVectorDimBitwidth >= targetVectorBitwidth)
714  return failure();
715  if (!vector::isContiguousSlice(sourceType, vectorType))
716  return failure();
717  // TODO: generalize this pattern, relax the requirements here.
718  if (transferWriteOp.hasOutOfBoundsDim())
719  return failure();
720  if (!transferWriteOp.getPermutationMap().isMinorIdentity())
721  return failure();
722  if (transferWriteOp.getMask())
723  return failure();
724 
  // The innermost `vectorType.getRank()` destination dims are the ones
  // collapsed.
725  int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
726 
727  // 1. Collapse the source memref
728  Value collapsedSource =
729  collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
730  MemRefType collapsedSourceType =
731  cast<MemRefType>(collapsedSource.getType());
732  int64_t collapsedRank = collapsedSourceType.getRank();
733  assert(collapsedRank == firstDimToCollapse + 1);
734 
735  // 2. Generate input args for a new vector.transfer_write that will write
736  // to the collapsed memref.
737  // 2.1. New dim exprs + affine map
739  getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
740  auto collapsedMap =
741  AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
742 
743  // 2.2 New indices
744  SmallVector<Value> collapsedIndices =
745  getCollapsedIndices(rewriter, loc, sourceType.getShape(),
746  transferWriteOp.getIndices(), firstDimToCollapse);
747 
748  // 3. Create new vector.transfer_write that writes to the collapsed memref
749  VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
750  vectorType.getElementType());
751  Value flatVector =
752  rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
753  vector::TransferWriteOp flatWrite =
754  rewriter.create<vector::TransferWriteOp>(
755  loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
  // The original write had no out-of-bounds dim, so the flat write is in
  // bounds.
756  flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
757 
758  // 4. Replace the old transfer_write with the new one writing the
759  // collapsed shape
760  rewriter.eraseOp(transferWriteOp);
761  return success();
762  }
763 
764 private:
765  // Flattening is only attempted when the trailing vector dimension (size times
766  // element bitwidth) is strictly narrower than this many bits.
767  unsigned targetVectorBitwidth;
768 };
769 
770 /// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
771 /// to `memref.load` patterns. The `match` method is shared for both
772 /// `vector.extract` and `vector.extractelement`.
/// Subclasses provide the `rewrite` half of the pattern.
773 template <class VectorExtractOp>
774 class RewriteScalarExtractOfTransferReadBase
775  : public OpRewritePattern<VectorExtractOp> {
777 
778 public:
779  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
780  PatternBenefit benefit,
781  bool allowMultipleUses)
782  : Base::OpRewritePattern(context, benefit),
783  allowMultipleUses(allowMultipleUses) {}
784 
  /// Succeeds when `extractOp` extracts a scalar from an unmasked,
  /// minor-identity, in-bounds vector.transfer_read whose uses satisfy the
  /// `allowMultipleUses` policy.
785  LogicalResult match(VectorExtractOp extractOp) const override {
786  auto xferOp =
787  extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
788  if (!xferOp)
789  return failure();
790  // Check that we are extracting a scalar and not a sub-vector.
791  if (isa<VectorType>(extractOp.getResult().getType()))
792  return failure();
793  // If multiple uses are not allowed, check if xfer has a single use.
794  if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
795  return failure();
796  // If multiple uses are allowed, check if all the xfer uses are extract ops.
797  if (allowMultipleUses &&
798  !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
799  return isa<vector::ExtractOp, vector::ExtractElementOp>(
800  use.getOwner());
801  }))
802  return failure();
803  // Mask not supported.
804  if (xferOp.getMask())
805  return failure();
806  // Map not supported.
807  if (!xferOp.getPermutationMap().isMinorIdentity())
808  return failure();
809  // Cannot rewrite if the indices may be out of bounds.
810  if (xferOp.hasOutOfBoundsDim())
811  return failure();
812  return success();
813  }
814 
815 private:
  // When false, only transfer_reads whose single use is this extract match.
816  bool allowMultipleUses;
817 };
818 
819 /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
820 ///
821 /// All the users of the transfer op must be either `vector.extractelement` or
822 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
823 /// transfer ops with any number of users. Otherwise, rewrite only if the
824 /// extract op is the single user of the transfer op. Rewriting a single
825 /// vector load with multiple scalar loads may negatively affect performance.
826 class RewriteScalarExtractElementOfTransferRead
827  : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
828  using RewriteScalarExtractOfTransferReadBase::
829  RewriteScalarExtractOfTransferReadBase;
830 
831  void rewrite(vector::ExtractElementOp extractOp,
832  PatternRewriter &rewriter) const override {
833  // Construct scalar load.
834  auto loc = extractOp.getLoc();
835  auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
836  SmallVector<Value> newIndices(xferOp.getIndices().begin(),
837  xferOp.getIndices().end());
838  if (extractOp.getPosition()) {
839  AffineExpr sym0, sym1;
840  bindSymbols(extractOp.getContext(), sym0, sym1);
842  rewriter, loc, sym0 + sym1,
843  {newIndices[newIndices.size() - 1], extractOp.getPosition()});
844  if (ofr.is<Value>()) {
845  newIndices[newIndices.size() - 1] = ofr.get<Value>();
846  } else {
847  newIndices[newIndices.size() - 1] =
848  rewriter.create<arith::ConstantIndexOp>(loc,
849  *getConstantIntValue(ofr));
850  }
851  }
852  if (isa<MemRefType>(xferOp.getSource().getType())) {
853  rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
854  newIndices);
855  } else {
856  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
857  extractOp, xferOp.getSource(), newIndices);
858  }
859  }
860 };
861 
862 /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
863 /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
864 ///
865 /// All the users of the transfer op must be either `vector.extractelement` or
866 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
867 /// transfer ops with any number of users. Otherwise, rewrite only if the
868 /// extract op is the single user of the transfer op. Rewriting a single
869 /// vector load with multiple scalar loads may negatively affect performance.
870 class RewriteScalarExtractOfTransferRead
871  : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
872  using RewriteScalarExtractOfTransferReadBase::
873  RewriteScalarExtractOfTransferReadBase;
874 
875  void rewrite(vector::ExtractOp extractOp,
876  PatternRewriter &rewriter) const override {
877  // Construct scalar load.
878  auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
879  SmallVector<Value> newIndices(xferOp.getIndices().begin(),
880  xferOp.getIndices().end());
881  for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
882  assert(pos.is<Attribute>() && "Unexpected non-constant index");
883  int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
884  int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
886  rewriter, extractOp.getLoc(),
887  rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
888  if (ofr.is<Value>()) {
889  newIndices[idx] = ofr.get<Value>();
890  } else {
891  newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
892  extractOp.getLoc(), *getConstantIntValue(ofr));
893  }
894  }
895  if (isa<MemRefType>(xferOp.getSource().getType())) {
896  rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
897  newIndices);
898  } else {
899  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
900  extractOp, xferOp.getSource(), newIndices);
901  }
902  }
903 };
904 
905 /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
906 /// to memref.store.
907 class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
909 
910  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
911  PatternRewriter &rewriter) const override {
912  // Must be a scalar write.
913  auto vecType = xferOp.getVectorType();
914  if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
915  return failure();
916  // Mask not supported.
917  if (xferOp.getMask())
918  return failure();
919  // Map not supported.
920  if (!xferOp.getPermutationMap().isMinorIdentity())
921  return failure();
922  // Only float and integer element types are supported.
923  Value scalar;
924  if (vecType.getRank() == 0) {
925  // vector.extract does not support vector<f32> etc., so use
926  // vector.extractelement instead.
927  scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
928  xferOp.getVector());
929  } else {
930  SmallVector<int64_t> pos(vecType.getRank(), 0);
931  scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
932  xferOp.getVector(), pos);
933  }
934  // Construct a scalar store.
935  if (isa<MemRefType>(xferOp.getSource().getType())) {
936  rewriter.replaceOpWithNewOp<memref::StoreOp>(
937  xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
938  } else {
939  rewriter.replaceOpWithNewOp<tensor::InsertOp>(
940  xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
941  }
942  return success();
943  }
944 };
945 
946 } // namespace
947 
949  Operation *rootOp) {
950  TransferOptimization opt(rewriter, rootOp);
951  // Run store to load forwarding first since it can expose more dead store
952  // opportunity.
953  rootOp->walk([&](vector::TransferReadOp read) {
954  if (isa<MemRefType>(read.getShapedType()))
955  opt.storeToLoadForwarding(read);
956  });
957  opt.removeDeadOp();
958  rootOp->walk([&](vector::TransferWriteOp write) {
959  if (isa<MemRefType>(write.getShapedType()))
960  opt.deadStoreOp(write);
961  });
962  opt.removeDeadOp();
963 }
964 
966  RewritePatternSet &patterns, PatternBenefit benefit,
967  bool allowMultipleUses) {
968  patterns.add<RewriteScalarExtractElementOfTransferRead,
969  RewriteScalarExtractOfTransferRead>(patterns.getContext(),
970  benefit, allowMultipleUses);
971  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
972 }
973 
975  RewritePatternSet &patterns, PatternBenefit benefit) {
  // Registers TransferReadDropUnitDimsPattern and
  // TransferWriteDropUnitDimsPattern, both at the caller-supplied `benefit`.
  // NOTE(review): the function-name line (974) and line 979 are missing from
  // this listing. The signature presumably matches the declaration
  // `populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &,
  // PatternBenefit)`; confirm nothing else was lost.
976  patterns
977  .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
978  patterns.getContext(), benefit);
980 }
981 
983  RewritePatternSet &patterns, unsigned targetVectorBitwidth,
984  PatternBenefit benefit) {
985  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
986  FlattenContiguousRowMajorTransferWritePattern>(
987  patterns.getContext(), targetVectorBitwidth, benefit);
988  populateShapeCastFoldingPatterns(patterns, benefit);
989  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
990 }
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter, mlir::Location loc, Value input)
Creates a rank-reducing memref.subview op that drops unit dims from its input.
static int getReducedRank(ArrayRef< int64_t > shape)
Returns the number of dims that aren't unit dims.
static SmallVector< int64_t > getReducedShape(ArrayRef< OpFoldResult > mixedSizes)
Converts OpFoldResults to int64_t shape without unit dims.
static VectorType trimNonScalableUnitDims(VectorType oldType)
Trims non-scalable one dimensions from oldType and returns the result type.
static FailureOr< Value > createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc, vector::CreateMaskOp op)
static Operation * findAncestorOpInRegion(Region *region, Operation *op)
Return the ancestor op in the region or nullptr if the region is not an ancestor of the op.
static SmallVector< Value > getCollapsedIndices(RewriterBase &rewriter, Location loc, ArrayRef< int64_t > shape, ValueRange indices, int64_t firstDimToCollapse)
Returns the new indices that collapses the inner dimensions starting from the firstDimToCollapse dime...
#define DBGS()
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, Value input, int64_t firstDimToCollapse)
Creates a memref.collapse_shape collapsing all inner dimensions of the input starting at firstDimToCo...
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides)
Drops unit dimensions from the input MemRefType.
Base type for affine expression.
Definition: AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
Definition: Attributes.h:25
bool isReachable(Block *other, SmallPtrSet< Block *, 16 > &&except={})
Return "true" if there is a path from this block to the given block (according to the successors rela...
Definition: Block.cpp:355
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:427
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:408
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:97
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:310
A class for computing basic dominance information.
Definition: Dominance.h:140
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
A class for computing basic postdominance information.
Definition: Dominance.h:197
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
Definition: Region.cpp:45
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1194
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b)
Checks if two (memref) values are the same or statically known to alias the same region of memory.
Definition: MemRefUtils.h:111
MemrefValue skipViewLikeOps(MemrefValue source)
Walk up the source chain until we find an operation that is not a view of the source memref (i....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns that use vector.shape_cast to help fold unit dims.
void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns, unsigned targetVectorBitwidth=std::numeric_limits< unsigned >::max(), PatternBenefit benefit=1)
Collect a set of patterns to flatten n-D vector transfers on contiguous memref.
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType)
Return true if vectorType is a contiguous slice of memrefType.
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector.shape_cast folding patterns.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of one dimension removal patterns.
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Definition: VectorOps.cpp:283
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully overwrites the priorWrite transfer_write op.
void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp)
Implements transfer op write to read forwarding and dead transfer write optimizations.
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit, bool allowMultipleUses)
Collects patterns that lower scalar vector transfer ops to memref loads and stores when beneficial.
Include the generated interface declarations.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with attribute with value 0.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:362
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:157