VectorTransferOpTransforms.cpp (MLIR 21.0.0git)
1 //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with optimizing transfer_read and
10 // transfer_write ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Utils/IndexingUtils.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/Interfaces/SideEffectInterfaces.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/Debug.h"
29 
30 #define DEBUG_TYPE "vector-transfer-opt"
31 
32 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
33 
34 using namespace mlir;
35 
36 /// Return the ancestor op in the region or nullptr if the region is not
37 /// an ancestor of the op.
38 static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
39   for (; op != nullptr && op->getParentRegion() != region;
40  op = op->getParentOp())
41  ;
42  return op;
43 }
44 
45 namespace {
46 
47 class TransferOptimization {
48 public:
49  TransferOptimization(RewriterBase &rewriter, Operation *op)
50  : rewriter(rewriter), dominators(op), postDominators(op) {}
51  void deadStoreOp(vector::TransferWriteOp);
52  void storeToLoadForwarding(vector::TransferReadOp);
53  void removeDeadOp() {
54  for (Operation *op : opToErase)
55  rewriter.eraseOp(op);
56  opToErase.clear();
57  }
58 
59 private:
60  RewriterBase &rewriter;
61  bool isReachable(Operation *start, Operation *dest);
62  DominanceInfo dominators;
63  PostDominanceInfo postDominators;
64  std::vector<Operation *> opToErase;
65 };
66 
67 } // namespace
68 /// Return true if there is a path from start operation to dest operation,
69 /// otherwise return false. The operations have to be in the same region.
70 bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
71   assert(start->getParentRegion() == dest->getParentRegion() &&
72          "This function only works for ops in the same region");
73   // Simple case where the start op dominates the destination.
74  if (dominators.dominates(start, dest))
75  return true;
76  return start->getBlock()->isReachable(dest->getBlock());
77 }
78 
79 /// For a transfer_write to fully overwrite another transfer_write, it must:
80 /// 1. Access the same memref with the same indices and vector type.
81 /// 2. Post-dominate the other transfer_write operation.
82 /// If several candidates are available, one must be post-dominated by all the
83 /// others since they all post-dominate the same transfer_write. We only
84 /// consider the transfer_write post-dominated by all the other candidates, as
85 /// this will be the first transfer_write executed after the potentially dead
86 /// transfer_write.
87 /// If we find such an overwriting transfer_write, we know that the original
88 /// transfer_write is dead if all reads that can be reached from the potentially
89 /// dead transfer_write are dominated by the overwriting transfer_write.
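/// Illustrative sketch (hypothetical IR, not taken from a test): the first
/// write below is dead because the second write post-dominates it, accesses
/// the same memref with the same indices and vector type, and no read of %m
/// can execute in between:
///
///   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///
/// deadStoreOp() would mark the first transfer_write for erasure.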
90 void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
91  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
92  << "\n");
93  llvm::SmallVector<Operation *, 8> blockingAccesses;
94  Operation *firstOverwriteCandidate = nullptr;
95  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
96  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
97  source.getUsers().end());
98  llvm::SmallDenseSet<Operation *, 32> processed;
99  while (!users.empty()) {
100  Operation *user = users.pop_back_val();
101     // If the user has already been processed, skip it.
102  if (!processed.insert(user).second)
103  continue;
104  if (isa<ViewLikeOpInterface>(user)) {
105  users.append(user->getUsers().begin(), user->getUsers().end());
106  continue;
107  }
108  if (isMemoryEffectFree(user))
109  continue;
110  if (user == write.getOperation())
111  continue;
112  if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
113       // Check if this candidate can overwrite the store.
114       if (memref::isSameViewOrTrivialAlias(
115               cast<MemrefValue>(nextWrite.getBase()),
116  cast<MemrefValue>(write.getBase())) &&
117  checkSameValueWAW(nextWrite, write) &&
118  postDominators.postDominates(nextWrite, write)) {
119  if (firstOverwriteCandidate == nullptr ||
120  postDominators.postDominates(firstOverwriteCandidate, nextWrite))
121  firstOverwriteCandidate = nextWrite;
122  else
123  assert(
124  postDominators.postDominates(nextWrite, firstOverwriteCandidate));
125  continue;
126  }
127  }
128  if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
129       // Don't need to consider disjoint accesses.
130       if (vector::isDisjointTransferSet(
131               cast<VectorTransferOpInterface>(write.getOperation()),
132  cast<VectorTransferOpInterface>(transferOp.getOperation()),
133  /*testDynamicValueUsingBounds=*/true))
134  continue;
135  }
136  blockingAccesses.push_back(user);
137  }
138  if (firstOverwriteCandidate == nullptr)
139  return;
140  Region *topRegion = firstOverwriteCandidate->getParentRegion();
141  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
142  assert(writeAncestor &&
143  "write op should be recursively part of the top region");
144 
145  for (Operation *access : blockingAccesses) {
146  Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
147     // TODO: if the access and write have the same ancestor, we could recurse
148     // into the region to check more precisely whether the access is reachable.
149  if (accessAncestor == nullptr ||
150  !isReachable(writeAncestor, accessAncestor))
151  continue;
152  if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
153  LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
154  << *accessAncestor << "\n");
155  return;
156  }
157  }
158  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
159  << " overwritten by: " << *firstOverwriteCandidate << "\n");
160  opToErase.push_back(write.getOperation());
161 }
162 
163 /// A transfer_write candidate for store-to-load forwarding must:
164 /// 1. Access the same memref with the same indices and vector type as the
165 ///    transfer_read.
166 /// 2. Dominate the transfer_read operation.
167 /// If several candidates are available, one must be dominated by all the others
168 /// since they all dominate the same transfer_read. We only consider the
169 /// transfer_write dominated by all the other candidates, as this will be the
170 /// last transfer_write executed before the transfer_read.
171 /// If we find such a candidate, we can do the forwarding if all the other
172 /// potentially aliasing ops that may reach the transfer_read are post-dominated
173 /// by the transfer_write.
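/// Illustrative sketch (hypothetical IR, not taken from a test): the read
/// below can simply reuse %v because the write dominates it, accesses the
/// same memref with the same indices and vector type, and no other write to
/// %m can occur in between:
///
///   vector.transfer_write %v, %m[%c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<8xf32>
///   %r = vector.transfer_read %m[%c0], %pad {in_bounds = [true]}
///       : memref<8xf32>, vector<4xf32>
///
/// storeToLoadForwarding() would replace all uses of %r with %v.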
174 void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
175  if (read.hasOutOfBoundsDim())
176  return;
177  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
178  << "\n");
179  SmallVector<Operation *, 8> blockingWrites;
180  vector::TransferWriteOp lastwrite = nullptr;
181  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
182  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
183  source.getUsers().end());
184  llvm::SmallDenseSet<Operation *, 32> processed;
185  while (!users.empty()) {
186  Operation *user = users.pop_back_val();
187  // If the user has already been processed skip.
188  if (!processed.insert(user).second)
189  continue;
190  if (isa<ViewLikeOpInterface>(user)) {
191  users.append(user->getUsers().begin(), user->getUsers().end());
192  continue;
193  }
194  if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
195  continue;
196  if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
197       // If there is a write, but we can prove that it is disjoint, we can
198       // ignore the write.
199       if (vector::isDisjointTransferSet(
200               cast<VectorTransferOpInterface>(write.getOperation()),
201               cast<VectorTransferOpInterface>(read.getOperation()),
202               /*testDynamicValueUsingBounds=*/true))
203         continue;
204       if (memref::isSameViewOrTrivialAlias(
205               cast<MemrefValue>(read.getBase()),
206  cast<MemrefValue>(write.getBase())) &&
207  dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
208  if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
209  lastwrite = write;
210  else
211  assert(dominators.dominates(write, lastwrite));
212  continue;
213  }
214  }
215  blockingWrites.push_back(user);
216  }
217 
218  if (lastwrite == nullptr)
219  return;
220 
221  Region *topRegion = lastwrite->getParentRegion();
222  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
223  assert(readAncestor &&
224  "read op should be recursively part of the top region");
225 
226  for (Operation *write : blockingWrites) {
227  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
228     // TODO: if the store and read have the same ancestor, we could recurse
229     // into the region to check more precisely whether the read is reachable.
230  if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
231  continue;
232  if (!postDominators.postDominates(lastwrite, write)) {
233  LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
234  << *write << "\n");
235  return;
236  }
237  }
238 
239  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
240  << " to: " << *read.getOperation() << "\n");
241  read.replaceAllUsesWith(lastwrite.getVector());
242  opToErase.push_back(read.getOperation());
243 }
244 
245 /// Converts OpFoldResults to int64_t shape without unit dims.
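/// For example (illustrative): mixed sizes [%d, 1, 4] are reduced to the
/// static shape [ShapedType::kDynamic, 4]; the unit dim is dropped and the
/// dynamic size becomes kDynamic.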
246 static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
247   SmallVector<int64_t> reducedShape;
248  for (const auto size : mixedSizes) {
249  if (llvm::dyn_cast_if_present<Value>(size)) {
250  reducedShape.push_back(ShapedType::kDynamic);
251  continue;
252  }
253 
254  auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
255  if (value == 1)
256  continue;
257  reducedShape.push_back(value.getSExtValue());
258  }
259  return reducedShape;
260 }
261 
262 /// Drops unit dimensions from the input MemRefType.
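/// For example (illustrative): memref<1x8x1x4xf32> with zero offsets, the full
/// sizes, and unit strides is rank-reduced to memref<8x4xf32> (with the layout
/// canonicalized by canonicalizeStridedLayout).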
263 static MemRefType dropUnitDims(MemRefType inputType,
264                                ArrayRef<OpFoldResult> offsets,
265                                ArrayRef<OpFoldResult> sizes,
266  ArrayRef<OpFoldResult> strides) {
267  auto targetShape = getReducedShape(sizes);
268  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
269  targetShape, inputType, offsets, sizes, strides);
270  return rankReducedType.canonicalizeStridedLayout();
271 }
272 
273 /// Creates a rank-reducing memref.subview op that drops unit dims from its
274 /// input, or just returns the input if it already has no unit dims.
275 static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
276                                                  mlir::Location loc,
277  Value input) {
278  MemRefType inputType = cast<MemRefType>(input.getType());
279  SmallVector<OpFoldResult> offsets(inputType.getRank(),
280  rewriter.getIndexAttr(0));
281  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
282  SmallVector<OpFoldResult> strides(inputType.getRank(),
283  rewriter.getIndexAttr(1));
284  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
285 
286  if (resultType.canonicalizeStridedLayout() ==
287  inputType.canonicalizeStridedLayout())
288  return input;
289  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
290  sizes, strides);
291 }
292 
293 /// Returns the number of dims that aren't unit dims.
294 static int getReducedRank(ArrayRef<int64_t> shape) {
295   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
296 }
297 
298 /// Trims non-scalable unit dimensions from `oldType` and returns the result
299 /// type.
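/// For example (illustrative): vector<1x[1]x4xf32> becomes vector<[1]x4xf32>;
/// the scalable unit dim is preserved while the fixed unit dim is dropped.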
300 static VectorType trimNonScalableUnitDims(VectorType oldType) {
301  SmallVector<int64_t> newShape;
302  SmallVector<bool> newScalableDims;
303  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
304  if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
305  continue;
306  newShape.push_back(dimSize);
307  newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
308  }
309  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
310 }
311 
312 // Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
313 static FailureOr<Value>
314 createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
315                                   vector::CreateMaskOp op) {
316  auto type = op.getType();
317  VectorType reducedType = trimNonScalableUnitDims(type);
318  if (reducedType.getRank() == type.getRank())
319  return failure();
320 
321  SmallVector<Value> reducedOperands;
322  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
323  type.getShape(), type.getScalableDims(), op.getOperands())) {
324  if (dim == 1 && !dimIsScalable) {
325  // If the mask for the unit dim is not a constant of 1, do nothing.
326  auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
327  if (!constant || (constant.value() != 1))
328  return failure();
329  continue;
330  }
331  reducedOperands.push_back(operand);
332  }
333  return rewriter
334  .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
335  .getResult();
336 }
337 
338 namespace {
339 
340 /// Rewrites `vector.transfer_read` ops where the source has unit dims, by
341 /// inserting a memref.subview dropping those unit dims. The vector shapes are
342 /// also reduced accordingly.
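/// Illustrative sketch (hypothetical IR): a read of vector<1x1x8xf32> from
/// memref<1x1x8xf32> becomes a rank-reducing memref.subview to memref<8xf32>,
/// a transfer_read of vector<8xf32>, and a vector.shape_cast back to
/// vector<1x1x8xf32>.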
343 class TransferReadDropUnitDimsPattern
344  : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
345  using MaskableOpRewritePattern::MaskableOpRewritePattern;
346 
347  FailureOr<Value>
348  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
349  vector::MaskingOpInterface maskingOp,
350  PatternRewriter &rewriter) const override {
351  auto loc = transferReadOp.getLoc();
352  Value vector = transferReadOp.getVector();
353  VectorType vectorType = cast<VectorType>(vector.getType());
354  Value source = transferReadOp.getBase();
355  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
356  // TODO: support tensor types.
357  if (!sourceType)
358  return failure();
359  // TODO: generalize this pattern, relax the requirements here.
360  if (transferReadOp.hasOutOfBoundsDim())
361  return failure();
362  if (!transferReadOp.getPermutationMap().isMinorIdentity())
363  return failure();
364  // Check if the source shape can be further reduced.
365  int reducedRank = getReducedRank(sourceType.getShape());
366  if (reducedRank == sourceType.getRank())
367  return failure();
368  // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
369  // out.
370  if (reducedRank == 0 && maskingOp)
371  return failure();
372  // Check if the reduced vector shape matches the reduced source shape.
373  // Otherwise, this case is not supported yet.
374  VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
375  if (reducedRank != reducedVectorType.getRank())
376  return failure();
377  if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
378  return getConstantIntValue(v) != static_cast<int64_t>(0);
379  }))
380  return failure();
381 
382  Value maskOp = transferReadOp.getMask();
383  if (maskOp) {
384  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
385  if (!createMaskOp)
386  return rewriter.notifyMatchFailure(
387  transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
388  "currently supported");
389  FailureOr<Value> rankReducedCreateMask =
390  createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
391  if (failed(rankReducedCreateMask))
392  return failure();
393  maskOp = *rankReducedCreateMask;
394  }
395 
396  Value reducedShapeSource =
397  rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
398  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
399  SmallVector<Value> zeros(reducedRank, c0);
400  auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
401  SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
402  Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
403  loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
404  transferReadOp.getPadding(), maskOp,
405  rewriter.getBoolArrayAttr(inBounds));
406 
407  if (maskingOp) {
408  auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
409  loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
410  maskingOp.getMask());
411  newTransferReadOp = mlir::vector::maskOperation(
412  rewriter, newTransferReadOp, shapeCastMask);
413  }
414 
415  auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
416  loc, vectorType, newTransferReadOp->getResults()[0]);
417 
418  return shapeCast;
419  }
420 };
421 
422 /// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
423 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
424 /// vector shapes are also reduced accordingly.
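/// Illustrative sketch (hypothetical IR): a write of vector<1x1x8xf32> into
/// memref<1x1x8xf32> becomes a vector.shape_cast to vector<8xf32> followed by
/// a transfer_write into a rank-reducing memref.subview of type memref<8xf32>.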
425 class TransferWriteDropUnitDimsPattern
426  : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
427  using MaskableOpRewritePattern::MaskableOpRewritePattern;
428 
429  FailureOr<Value>
430  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
431  vector::MaskingOpInterface maskingOp,
432  PatternRewriter &rewriter) const override {
433  auto loc = transferWriteOp.getLoc();
434  Value vector = transferWriteOp.getVector();
435  VectorType vectorType = cast<VectorType>(vector.getType());
436  Value source = transferWriteOp.getBase();
437  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
438  // TODO: support tensor type.
439  if (!sourceType)
440  return failure();
441  // TODO: generalize this pattern, relax the requirements here.
442  if (transferWriteOp.hasOutOfBoundsDim())
443  return failure();
444  if (!transferWriteOp.getPermutationMap().isMinorIdentity())
445  return failure();
446  // Check if the destination shape can be further reduced.
447  int reducedRank = getReducedRank(sourceType.getShape());
448  if (reducedRank == sourceType.getRank())
449  return failure();
450  // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
451  // out.
452  if (reducedRank == 0 && maskingOp)
453  return failure();
454  // Check if the reduced vector shape matches the reduced destination shape.
455  // Otherwise, this case is not supported yet.
456  VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
457  if (reducedRank != reducedVectorType.getRank())
458  return failure();
459  if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
460  return getConstantIntValue(v) != static_cast<int64_t>(0);
461  }))
462  return failure();
463 
464  Value maskOp = transferWriteOp.getMask();
465  if (maskOp) {
466  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
467  if (!createMaskOp)
468  return rewriter.notifyMatchFailure(
469  transferWriteOp,
470  "unsupported mask op, only 'vector.create_mask' is "
471  "currently supported");
472  FailureOr<Value> rankReducedCreateMask =
473  createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
474  if (failed(rankReducedCreateMask))
475  return failure();
476  maskOp = *rankReducedCreateMask;
477  }
478  Value reducedShapeSource =
479  rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
480  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
481  SmallVector<Value> zeros(reducedRank, c0);
482  auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
483  SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
484  auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
485  loc, reducedVectorType, vector);
486  Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
487  loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
488  maskOp, rewriter.getBoolArrayAttr(inBounds));
489 
490  if (maskingOp) {
491  auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
492  loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
493  maskingOp.getMask());
494  newXferWrite =
495  mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
496  }
497 
498  if (transferWriteOp.hasPureTensorSemantics())
499  return newXferWrite->getResults()[0];
500 
501  // With Memref semantics, there's no return value. Use empty value to signal
502  // success.
503  return Value();
504  }
505 };
506 
507 } // namespace
508 
509 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
510 /// input starting at `firstDimToCollapse`.
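/// For example (illustrative): with firstDimToCollapse = 1, a value of type
/// memref<2x3x4xf32> is collapsed to memref<2x12xf32> using a
/// memref.collapse_shape with reassociation [[0], [1, 2]].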
511 static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
512                                Value input, int64_t firstDimToCollapse) {
513  ShapedType inputType = cast<ShapedType>(input.getType());
514  if (inputType.getRank() == 1)
515  return input;
516  SmallVector<ReassociationIndices> reassociation;
517  for (int64_t i = 0; i < firstDimToCollapse; ++i)
518  reassociation.push_back(ReassociationIndices{i});
519  ReassociationIndices collapsedIndices;
520  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
521  collapsedIndices.push_back(i);
522  reassociation.push_back(collapsedIndices);
523  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
524 }
525 
526 /// Returns the new indices that collapse the inner dimensions starting from
527 /// the `firstDimToCollapse` dimension.
528 static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
529                                               Location loc,
530  ArrayRef<int64_t> shape,
531  ValueRange indices,
532  int64_t firstDimToCollapse) {
533  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
534 
535  // If all the collapsed indices are zero then no extra logic is needed.
536  // Otherwise, a new offset/index has to be computed.
537  SmallVector<Value> indicesAfterCollapsing(
538  indices.begin(), indices.begin() + firstDimToCollapse);
539  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
540  indices.end());
541  if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
542  indicesAfterCollapsing.push_back(indicesToCollapse[0]);
543  return indicesAfterCollapsing;
544  }
545 
546  // Compute the remaining trailing index/offset required for reading from
547  // the collapsed memref:
548  //
549  // offset = 0
550  // for (i = firstDimToCollapse; i < outputRank; ++i)
551   //      offset += collapsedStrides[i] * transferReadOp.indices[i]
552  //
553  // For this example:
554  // %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
555  // memref<1x43x2xi32>, vector<1x2xi32>
556  // which would be collapsed to:
557  // %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
558  // memref<1x86xi32>, vector<2xi32>
559  // one would get the following offset:
560   //    %offset = %arg0 * 2
561  OpFoldResult collapsedOffset =
562  rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
563 
564  auto collapsedStrides = computeSuffixProduct(
565  ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
566 
567  // Compute the collapsed offset.
568  auto &&[collapsedExpr, collapsedVals] =
569  computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
570  collapsedOffset = affine::makeComposedFoldedAffineApply(
571  rewriter, loc, collapsedExpr, collapsedVals);
572 
573  if (auto value = dyn_cast<Value>(collapsedOffset)) {
574  indicesAfterCollapsing.push_back(value);
575  } else {
576  indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
577  loc, *getConstantIntValue(collapsedOffset)));
578  }
579 
580  return indicesAfterCollapsing;
581 }
582 
583 namespace {
584 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
585 /// memref.collapse_shape on the source so that the resulting
586 /// vector.transfer_read has a 1D source. Requires the source shape to be
587 /// already reduced, i.e., without unit dims.
588 ///
589 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
590 /// the total bitwidth of the trailing vector dimension is smaller than the
591 /// provided bitwidth.
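/// Illustrative sketch (hypothetical IR): a contiguous read of vector<2x4xf32>
/// from memref<8x4xf32> at indices [%i, %c0] becomes a memref.collapse_shape
/// to memref<32xf32>, a transfer_read of vector<8xf32> at the linearized index
/// %i * 4, and a vector.shape_cast back to vector<2x4xf32>.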
592 class FlattenContiguousRowMajorTransferReadPattern
593  : public OpRewritePattern<vector::TransferReadOp> {
594 public:
595  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
596  unsigned vectorBitwidth,
597  PatternBenefit benefit)
598  : OpRewritePattern<vector::TransferReadOp>(context, benefit),
599  targetVectorBitwidth(vectorBitwidth) {}
600 
601  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
602  PatternRewriter &rewriter) const override {
603  auto loc = transferReadOp.getLoc();
604  Value vector = transferReadOp.getVector();
605  VectorType vectorType = cast<VectorType>(vector.getType());
606  auto source = transferReadOp.getBase();
607  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
608 
609  // 0. Check pre-conditions
610     // The contiguity check below is only supported for memrefs.
611  if (!sourceType)
612  return failure();
613  // If this is already 0D/1D, there's nothing to do.
614  if (vectorType.getRank() <= 1)
615  return failure();
616  if (!vectorType.getElementType().isSignlessIntOrFloat())
617  return failure();
618  unsigned trailingVectorDimBitwidth =
619  vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
620  if (trailingVectorDimBitwidth >= targetVectorBitwidth)
621  return failure();
622  if (!vector::isContiguousSlice(sourceType, vectorType))
623  return failure();
624  // TODO: generalize this pattern, relax the requirements here.
625  if (transferReadOp.hasOutOfBoundsDim())
626  return failure();
627  if (!transferReadOp.getPermutationMap().isMinorIdentity())
628  return failure();
629  if (transferReadOp.getMask())
630  return failure();
631 
632  // Determine the first memref dimension to collapse - just enough so we can
633  // read a flattened vector.
634  int64_t firstDimToCollapse =
635  sourceType.getRank() -
636  vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
637 
638  // 1. Collapse the source memref
639  Value collapsedSource =
640  collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
641  MemRefType collapsedSourceType =
642  cast<MemRefType>(collapsedSource.getType());
643  int64_t collapsedRank = collapsedSourceType.getRank();
644  assert(collapsedRank == firstDimToCollapse + 1);
645 
646  // 2. Generate input args for a new vector.transfer_read that will read
647  // from the collapsed memref.
648  // 2.1. New dim exprs + affine map
649     SmallVector<AffineExpr, 1> dimExprs{
650         getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
651  auto collapsedMap =
652  AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
653 
654  // 2.2 New indices
655  SmallVector<Value> collapsedIndices =
656  getCollapsedIndices(rewriter, loc, sourceType.getShape(),
657  transferReadOp.getIndices(), firstDimToCollapse);
658 
659  // 3. Create new vector.transfer_read that reads from the collapsed memref
660  VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
661  vectorType.getElementType());
662  vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
663  loc, flatVectorType, collapsedSource, collapsedIndices,
664  transferReadOp.getPadding(), collapsedMap);
665  flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
666 
667  // 4. Replace the old transfer_read with the new one reading from the
668  // collapsed shape
669  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
670  transferReadOp, cast<VectorType>(vector.getType()), flatRead);
671  return success();
672  }
673 
674 private:
675  // Minimum bitwidth that the trailing vector dimension should have after
676  // flattening.
677  unsigned targetVectorBitwidth;
678 };
679 
680 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
681 /// memref.collapse_shape on the source so that the resulting
682 /// vector.transfer_write has a 1D source. Requires the source shape to be
683 /// already reduced, i.e., without unit dims.
684 ///
685 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
686 /// the total bitwidth of the trailing vector dimension is smaller than the
687 /// provided bitwidth.
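/// Illustrative sketch (hypothetical IR): a contiguous write of vector<2x4xf32>
/// into memref<8x4xf32> at indices [%i, %c0] becomes a vector.shape_cast to
/// vector<8xf32> followed by a transfer_write into the collapsed memref<32xf32>
/// at the linearized index %i * 4.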
688 class FlattenContiguousRowMajorTransferWritePattern
689  : public OpRewritePattern<vector::TransferWriteOp> {
690 public:
691  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
692  unsigned vectorBitwidth,
693  PatternBenefit benefit)
694  : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
695  targetVectorBitwidth(vectorBitwidth) {}
696 
697  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
698  PatternRewriter &rewriter) const override {
699  auto loc = transferWriteOp.getLoc();
700  Value vector = transferWriteOp.getVector();
701  VectorType vectorType = cast<VectorType>(vector.getType());
702  Value source = transferWriteOp.getBase();
703  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
704 
705  // 0. Check pre-conditions
706     // The contiguity check below is only supported for memrefs.
707  if (!sourceType)
708  return failure();
709  // If this is already 0D/1D, there's nothing to do.
710  if (vectorType.getRank() <= 1)
711  // Already 0D/1D, nothing to do.
712  return failure();
713  if (!vectorType.getElementType().isSignlessIntOrFloat())
714  return failure();
715  unsigned trailingVectorDimBitwidth =
716  vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
717  if (trailingVectorDimBitwidth >= targetVectorBitwidth)
718  return failure();
719  if (!vector::isContiguousSlice(sourceType, vectorType))
720  return failure();
721  // TODO: generalize this pattern, relax the requirements here.
722  if (transferWriteOp.hasOutOfBoundsDim())
723  return failure();
724  if (!transferWriteOp.getPermutationMap().isMinorIdentity())
725  return failure();
726  if (transferWriteOp.getMask())
727  return failure();
728 
729     // Determine the first memref dimension to collapse - just enough so we can
730     // write a flattened vector.
731  int64_t firstDimToCollapse =
732  sourceType.getRank() -
733  vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
734 
735  // 1. Collapse the source memref
736  Value collapsedSource =
737  collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
738  MemRefType collapsedSourceType =
739  cast<MemRefType>(collapsedSource.getType());
740  int64_t collapsedRank = collapsedSourceType.getRank();
741  assert(collapsedRank == firstDimToCollapse + 1);
742 
743     // 2. Generate input args for a new vector.transfer_write that will write
744     // to the collapsed memref.
745  // 2.1. New dim exprs + affine map
746     SmallVector<AffineExpr, 1> dimExprs{
747         getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
748  auto collapsedMap =
749  AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
750 
751  // 2.2 New indices
752  SmallVector<Value> collapsedIndices =
753  getCollapsedIndices(rewriter, loc, sourceType.getShape(),
754  transferWriteOp.getIndices(), firstDimToCollapse);
755 
756  // 3. Create new vector.transfer_write that writes to the collapsed memref
757  VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
758  vectorType.getElementType());
759  Value flatVector =
760  rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
761  vector::TransferWriteOp flatWrite =
762  rewriter.create<vector::TransferWriteOp>(
763  loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
764  flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
765 
766  // 4. Replace the old transfer_write with the new one writing the
767  // collapsed shape
768  rewriter.eraseOp(transferWriteOp);
769  return success();
770  }
771 
772 private:
773  // Minimum bitwidth that the trailing vector dimension should have after
774  // flattening.
775  unsigned targetVectorBitwidth;
776 };
777 
778 /// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
779 /// to `memref.load` patterns. The `match` method is shared for both
780 /// `vector.extract` and `vector.extractelement`.
781 template <class VectorExtractOp>
782 class RewriteScalarExtractOfTransferReadBase
783  : public OpRewritePattern<VectorExtractOp> {
784   using Base = OpRewritePattern<VectorExtractOp>;
785 
786 public:
787  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
788  PatternBenefit benefit,
789  bool allowMultipleUses)
790  : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}
791 
792  LogicalResult match(VectorExtractOp extractOp) const {
793  auto xferOp =
794  extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
795  if (!xferOp)
796  return failure();
797  // Check that we are extracting a scalar and not a sub-vector.
798  if (isa<VectorType>(extractOp.getResult().getType()))
799  return failure();
800  // If multiple uses are not allowed, check if xfer has a single use.
801  if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
802  return failure();
803  // If multiple uses are allowed, check if all the xfer uses are extract ops.
804  if (allowMultipleUses &&
805  !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
806  return isa<vector::ExtractOp, vector::ExtractElementOp>(
807  use.getOwner());
808  }))
809  return failure();
810  // Mask not supported.
811  if (xferOp.getMask())
812  return failure();
813  // Map not supported.
814  if (!xferOp.getPermutationMap().isMinorIdentity())
815  return failure();
816  // Cannot rewrite if the indices may be out of bounds.
817  if (xferOp.hasOutOfBoundsDim())
818  return failure();
819  return success();
820  }
821 
822 private:
823  bool allowMultipleUses;
824 };
825 
826 /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
827 ///
828 /// All the users of the transfer op must be either `vector.extractelement` or
829 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
830 /// transfer ops with any number of users. Otherwise, rewrite only if the
831 /// extract op is the single user of the transfer op. Rewriting a single
832 /// vector load with multiple scalar loads may negatively affect performance.
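/// Illustrative sketch (hypothetical IR):
///   %v = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
///       : memref<8xf32>, vector<4xf32>
///   %e = vector.extractelement %v[%j : index] : vector<4xf32>
/// becomes a single scalar load at the composed index %i + %j (folded with
/// affine.apply when it is not constant):
///   %e = memref.load %m[%idx] : memref<8xf32>
/// where %idx is the hypothetical value holding %i + %j.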
833 class RewriteScalarExtractElementOfTransferRead
834  : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
835  using RewriteScalarExtractOfTransferReadBase::
836  RewriteScalarExtractOfTransferReadBase;
837 
838  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
839  PatternRewriter &rewriter) const override {
840  if (failed(match(extractOp)))
841  return failure();
842 
843  // Construct scalar load.
844  auto loc = extractOp.getLoc();
845  auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
846  SmallVector<Value> newIndices(xferOp.getIndices().begin(),
847  xferOp.getIndices().end());
848  if (extractOp.getPosition()) {
849  AffineExpr sym0, sym1;
850  bindSymbols(extractOp.getContext(), sym0, sym1);
851       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
852           rewriter, loc, sym0 + sym1,
853  {newIndices[newIndices.size() - 1], extractOp.getPosition()});
854  if (auto value = dyn_cast<Value>(ofr)) {
855  newIndices[newIndices.size() - 1] = value;
856  } else {
857  newIndices[newIndices.size() - 1] =
858  rewriter.create<arith::ConstantIndexOp>(loc,
859  *getConstantIntValue(ofr));
860  }
861  }
862  if (isa<MemRefType>(xferOp.getBase().getType())) {
863  rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
864  newIndices);
865  } else {
866  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
867  extractOp, xferOp.getBase(), newIndices);
868  }
869 
870  return success();
871  }
872 };
873 
874 /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`. This is
875 /// the `vector.extract` analogue of the pattern above.
876 ///
877 /// All the users of the transfer op must be either `vector.extractelement` or
878 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
879 /// transfer ops with any number of users. Otherwise, rewrite only if the
880 /// extract op is the single user of the transfer op. Rewriting a single
881 /// vector load with multiple scalar loads may negatively affect performance.
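/// Illustrative sketch (hypothetical IR):
///   %v = vector.transfer_read %m[%i, %c0], %pad {in_bounds = [true]}
///       : memref<4x8xf32>, vector<8xf32>
///   %e = vector.extract %v[3] : f32 from vector<8xf32>
/// becomes a scalar load at the composed trailing index %c0 + 3:
///   %e = memref.load %m[%i, %c3] : memref<4x8xf32>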
882 class RewriteScalarExtractOfTransferRead
883  : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
884  using RewriteScalarExtractOfTransferReadBase::
885  RewriteScalarExtractOfTransferReadBase;
886 
887  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
888  PatternRewriter &rewriter) const override {
889  if (failed(match(extractOp)))
890  return failure();
891 
892  // Construct scalar load.
893  auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
894  SmallVector<Value> newIndices(xferOp.getIndices().begin(),
895  xferOp.getIndices().end());
896  for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
897  int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
898 
899  // Compute affine expression `newIndices[idx] + pos` where `pos` can be
900  // either a constant or a value.
901  OpFoldResult composedIdx;
902  if (auto attr = dyn_cast<Attribute>(pos)) {
903  int64_t offset = cast<IntegerAttr>(attr).getInt();
904         composedIdx = affine::makeComposedFoldedAffineApply(
905             rewriter, extractOp.getLoc(),
906  rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
907  } else {
908  Value dynamicOffset = cast<Value>(pos);
909  AffineExpr sym0, sym1;
910  bindSymbols(rewriter.getContext(), sym0, sym1);
911         composedIdx = affine::makeComposedFoldedAffineApply(
912             rewriter, extractOp.getLoc(), sym0 + sym1,
913  {newIndices[idx], dynamicOffset});
914  }
915 
916  // Update the corresponding index with the folded result.
917  if (auto value = dyn_cast<Value>(composedIdx)) {
918  newIndices[idx] = value;
919  } else {
920  newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
921  extractOp.getLoc(), *getConstantIntValue(composedIdx));
922  }
923  }
924  if (isa<MemRefType>(xferOp.getBase().getType())) {
925  rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
926  newIndices);
927  } else {
928  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
929  extractOp, xferOp.getBase(), newIndices);
930  }
931 
932  return success();
933  }
934 };
935 
936 /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
937 /// to memref.store.
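/// Illustrative sketch (hypothetical IR):
///   vector.transfer_write %v, %m[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<4x4xf32>
/// becomes, roughly, an extract of the single element followed by a store:
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %m[%i, %j] : memref<4x4xf32>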
938 class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
939   using OpRewritePattern::OpRewritePattern;
940 
941  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
942  PatternRewriter &rewriter) const override {
943  // Must be a scalar write.
944  auto vecType = xferOp.getVectorType();
945  if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
946  return failure();
947  // Mask not supported.
948  if (xferOp.getMask())
949  return failure();
950  // Map not supported.
951  if (!xferOp.getPermutationMap().isMinorIdentity())
952  return failure();
953  // Only float and integer element types are supported.
954  Value scalar =
955  rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
956  // Construct a scalar store.
957  if (isa<MemRefType>(xferOp.getBase().getType())) {
958  rewriter.replaceOpWithNewOp<memref::StoreOp>(
959  xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
960  } else {
961  rewriter.replaceOpWithNewOp<tensor::InsertOp>(
962  xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
963  }
964  return success();
965  }
966 };
967 
968 } // namespace
969 
970 void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
971                                      Operation *rootOp) {
972  TransferOptimization opt(rewriter, rootOp);
973   // Run store to load forwarding first since it can expose more dead store
974   // opportunities.
975  rootOp->walk([&](vector::TransferReadOp read) {
976  if (isa<MemRefType>(read.getShapedType()))
977  opt.storeToLoadForwarding(read);
978  });
979  opt.removeDeadOp();
980  rootOp->walk([&](vector::TransferWriteOp write) {
981  if (isa<MemRefType>(write.getShapedType()))
982  opt.deadStoreOp(write);
983  });
984  opt.removeDeadOp();
985 }
986 
987 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
988     RewritePatternSet &patterns, PatternBenefit benefit,
989     bool allowMultipleUses) {
990  patterns.add<RewriteScalarExtractElementOfTransferRead,
991  RewriteScalarExtractOfTransferRead>(patterns.getContext(),
992  benefit, allowMultipleUses);
993  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
994 }
995 
996 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
997     RewritePatternSet &patterns, PatternBenefit benefit) {
998   patterns
999  .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
1000  patterns.getContext(), benefit);
1001 }
1002 
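// Hedged usage sketch (not part of this file): one way a pass could apply the
// flattening patterns defined above, assuming the greedy driver entry point
// applyPatternsGreedily and a 128-bit target vector bitwidth:
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateFlattenVectorTransferPatterns(
//       patterns, /*targetVectorBitwidth=*/128, /*benefit=*/1);
//   (void)applyPatternsGreedily(getOperation(), std::move(patterns));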
1003 void mlir::vector::populateFlattenVectorTransferPatterns(
1004  RewritePatternSet &patterns, unsigned targetVectorBitwidth,
1005  PatternBenefit benefit) {
1006  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
1007  FlattenContiguousRowMajorTransferWritePattern>(
1008  patterns.getContext(), targetVectorBitwidth, benefit);
1009  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
1010 }