VectorTransferOpTransforms.cpp
1 //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with optimizing transfer_read and
10 // transfer_write ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Utils/IndexingUtils.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/Interfaces/SideEffectInterfaces.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/Debug.h"
29 
30 #define DEBUG_TYPE "vector-transfer-opt"
31 
32 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
33 
34 using namespace mlir;
35 
36 /// Return the ancestor op in the region or nullptr if the region is not
37 /// an ancestor of the op.
38 static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
39  for (; op != nullptr && op->getParentRegion() != region;
40  op = op->getParentOp())
41  ;
42  return op;
43 }
44 
45 namespace {
46 
47 class TransferOptimization {
48 public:
49  TransferOptimization(RewriterBase &rewriter, Operation *op)
50  : rewriter(rewriter), dominators(op), postDominators(op) {}
51  void deadStoreOp(vector::TransferWriteOp);
52  void storeToLoadForwarding(vector::TransferReadOp);
53  void removeDeadOp() {
54  for (Operation *op : opToErase)
55  rewriter.eraseOp(op);
56  opToErase.clear();
57  }
58 
59 private:
60  RewriterBase &rewriter;
61  bool isReachable(Operation *start, Operation *dest);
62  DominanceInfo dominators;
63  PostDominanceInfo postDominators;
64  std::vector<Operation *> opToErase;
65 };
66 
67 } // namespace
68 /// Return true if there is a path from start operation to dest operation,
69 /// otherwise return false. The operations have to be in the same region.
70 bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
71  assert(start->getParentRegion() == dest->getParentRegion() &&
72  "This function only works for ops in the same region");
73  // Simple case where the start op dominates the destination.
74  if (dominators.dominates(start, dest))
75  return true;
76  Block *startBlock = start->getBlock();
77  Block *destBlock = dest->getBlock();
78  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
79  startBlock->succ_end());
80  llvm::SmallDenseSet<Block *, 32> visited;
81  while (!worklist.empty()) {
82  Block *bb = worklist.pop_back_val();
83  if (!visited.insert(bb).second)
84  continue;
85  if (dominators.dominates(bb, destBlock))
86  return true;
87  worklist.append(bb->succ_begin(), bb->succ_end());
88  }
89  return false;
90 }
91 
92 /// For transfer_write to overwrite fully another transfer_write must:
93 /// 1. Access the same memref with the same indices and vector type.
94 /// 2. Post-dominate the other transfer_write operation.
95 /// If several candidates are available, one must be post-dominated by all the
96 /// others since they are all post-dominating the same transfer_write. We only
97 /// consider the transfer_write post-dominated by all the other candidates as
98 /// this will be the first transfer_write executed after the potentially dead
99 /// transfer_write.
100 /// If we find such an overwriting transfer_write, we know that the original
101 /// transfer_write is dead if all reads that can be reached from the potentially
102 /// dead transfer_write are dominated by the overwriting transfer_write.
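/// For illustration, a sketch of IR this targets (hypothetical values and
/// shapes, not taken from a test): the first write below is dead because the
/// second write overwrites the same region before any read observes it.
///   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %cst {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>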
103 void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
104  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
105  << "\n");
106  llvm::SmallVector<Operation *, 8> blockingAccesses;
107  Operation *firstOverwriteCandidate = nullptr;
108  Value source =
109  memref::skipSubViewsAndCasts(cast<MemrefValue>(write.getSource()));
110  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
111  source.getUsers().end());
112  llvm::SmallDenseSet<Operation *, 32> processed;
113  while (!users.empty()) {
114  Operation *user = users.pop_back_val();
115  // If the user has already been processed skip.
116  if (!processed.insert(user).second)
117  continue;
118  if (isa<memref::SubViewOp, memref::CastOp>(user)) {
119  users.append(user->getUsers().begin(), user->getUsers().end());
120  continue;
121  }
122  if (isMemoryEffectFree(user))
123  continue;
124  if (user == write.getOperation())
125  continue;
126  if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
127  // Check whether the candidate can overwrite the store.
128  if (memref::isSameViewOrTrivialAlias(
129  cast<MemrefValue>(nextWrite.getSource()),
130  cast<MemrefValue>(write.getSource())) &&
131  checkSameValueWAW(nextWrite, write) &&
132  postDominators.postDominates(nextWrite, write)) {
133  if (firstOverwriteCandidate == nullptr ||
134  postDominators.postDominates(firstOverwriteCandidate, nextWrite))
135  firstOverwriteCandidate = nextWrite;
136  else
137  assert(
138  postDominators.postDominates(nextWrite, firstOverwriteCandidate));
139  continue;
140  }
141  }
142  if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
143  // Don't need to consider disjoint accesses.
144  if (vector::isDisjointTransferSet(
145  cast<VectorTransferOpInterface>(write.getOperation()),
146  cast<VectorTransferOpInterface>(transferOp.getOperation()),
147  /*testDynamicValueUsingBounds=*/true))
148  continue;
149  }
150  blockingAccesses.push_back(user);
151  }
152  if (firstOverwriteCandidate == nullptr)
153  return;
154  Region *topRegion = firstOverwriteCandidate->getParentRegion();
155  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
156  assert(writeAncestor &&
157  "write op should be recursively part of the top region");
158 
159  for (Operation *access : blockingAccesses) {
160  Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
161  // TODO: if the access and write have the same ancestor we could recurse in
162  // the region to know if the access is reachable with more precision.
163  if (accessAncestor == nullptr ||
164  !isReachable(writeAncestor, accessAncestor))
165  continue;
166  if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
167  LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
168  << *accessAncestor << "\n");
169  return;
170  }
171  }
172  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
173  << " overwritten by: " << *firstOverwriteCandidate << "\n");
174  opToErase.push_back(write.getOperation());
175 }
176 
177 /// A transfer_write candidate for store-to-load forwarding must:
178 /// 1. Access the same memref with the same indices and vector type as the
179 /// transfer_read.
180 /// 2. Dominate the transfer_read operation.
181 /// If several candidates are available, one must be dominated by all the others
182 /// since they are all dominating the same transfer_read. We only consider the
183 /// transfer_write dominated by all the other candidates as this will be the
184 /// last transfer_write executed before the transfer_read.
185 /// If we find such a candidate, we can do the forwarding if all the other
186 /// potentially aliasing ops that may reach the transfer_read are post-dominated
187 /// by the transfer_write.
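/// For illustration (hypothetical IR): the read below can be replaced by %v
/// because the dominating write stores exactly the value being read back.
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %cst {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
/// After forwarding, %r is replaced by %v; if the write then becomes dead it
/// is removed by `deadStoreOp`.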
188 void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
189  if (read.hasOutOfBoundsDim())
190  return;
191  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
192  << "\n");
193  SmallVector<Operation *, 8> blockingWrites;
194  vector::TransferWriteOp lastwrite = nullptr;
195  Value source =
196  memref::skipSubViewsAndCasts(cast<MemrefValue>(read.getSource()));
197  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
198  source.getUsers().end());
199  llvm::SmallDenseSet<Operation *, 32> processed;
200  while (!users.empty()) {
201  Operation *user = users.pop_back_val();
202  // If the user has already been processed skip.
203  if (!processed.insert(user).second)
204  continue;
205  if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
206  users.append(user->getUsers().begin(), user->getUsers().end());
207  continue;
208  }
209  if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
210  continue;
211  if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
212  // If there is a write, but we can prove that it is disjoint, we can
213  // ignore the write.
214  if (vector::isDisjointTransferSet(
215  cast<VectorTransferOpInterface>(write.getOperation()),
216  cast<VectorTransferOpInterface>(read.getOperation()),
217  /*testDynamicValueUsingBounds=*/true))
218  continue;
219  if (memref::isSameViewOrTrivialAlias(
220  cast<MemrefValue>(read.getSource()),
221  cast<MemrefValue>(write.getSource())) &&
222  dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
223  if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
224  lastwrite = write;
225  else
226  assert(dominators.dominates(write, lastwrite));
227  continue;
228  }
229  }
230  blockingWrites.push_back(user);
231  }
232 
233  if (lastwrite == nullptr)
234  return;
235 
236  Region *topRegion = lastwrite->getParentRegion();
237  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
238  assert(readAncestor &&
239  "read op should be recursively part of the top region");
240 
241  for (Operation *write : blockingWrites) {
242  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
243  // TODO: if the store and read have the same ancestor we could recurse in
244  // the region to know if the read is reachable with more precision.
245  if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
246  continue;
247  if (!postDominators.postDominates(lastwrite, write)) {
248  LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
249  << *write << "\n");
250  return;
251  }
252  }
253 
254  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
255  << " to: " << *read.getOperation() << "\n");
256  read.replaceAllUsesWith(lastwrite.getVector());
257  opToErase.push_back(read.getOperation());
258 }
259 
260 /// Converts OpFoldResults to int64_t shape without unit dims.
261 static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
262  SmallVector<int64_t> reducedShape;
263  for (const auto size : mixedSizes) {
264  if (llvm::dyn_cast_if_present<Value>(size)) {
265  reducedShape.push_back(ShapedType::kDynamic);
266  continue;
267  }
268 
269  auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
270  if (value == 1)
271  continue;
272  reducedShape.push_back(value.getSExtValue());
273  }
274  return reducedShape;
275 }
276 
277 /// Drops unit dimensions from the input MemRefType.
278 static MemRefType dropUnitDims(MemRefType inputType,
279  ArrayRef<OpFoldResult> offsets,
280  ArrayRef<OpFoldResult> sizes,
281  ArrayRef<OpFoldResult> strides) {
282  auto targetShape = getReducedShape(sizes);
283  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
284  targetShape, inputType, offsets, sizes, strides);
285  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
286 }
287 
288 /// Creates a rank-reducing memref.subview op that drops unit dims from its
289 /// input, or just returns the input if it already has no unit dims.
290 static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
291  mlir::Location loc,
292  Value input) {
293  MemRefType inputType = cast<MemRefType>(input.getType());
294  SmallVector<OpFoldResult> offsets(inputType.getRank(),
295  rewriter.getIndexAttr(0));
296  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
297  SmallVector<OpFoldResult> strides(inputType.getRank(),
298  rewriter.getIndexAttr(1));
299  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
300 
301  if (canonicalizeStridedLayout(resultType) ==
302  canonicalizeStridedLayout(inputType))
303  return input;
304  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
305  sizes, strides);
306 }
307 
308 /// Returns the number of dims that aren't unit dims.
309 static int getReducedRank(ArrayRef<int64_t> shape) {
310  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
311 }
312 
313 /// Trims non-scalable unit dimensions from `oldType` and returns the result
314 /// type.
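/// For example (illustrative types): vector<1x[1]x4xf32> is trimmed to
/// vector<[1]x4xf32>; the scalable unit dim `[1]` is kept because its runtime
/// size need not be 1.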
315 static VectorType trimNonScalableUnitDims(VectorType oldType) {
316  SmallVector<int64_t> newShape;
317  SmallVector<bool> newScalableDims;
318  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
319  if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
320  continue;
321  newShape.push_back(dimSize);
322  newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
323  }
324  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
325 }
326 
327 // Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
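// For illustration (hypothetical IR): with %c1 = arith.constant 1 : index,
//   %mask = vector.create_mask %c1, %sz : vector<1x[4]xi1>
// can be rewritten to
//   %mask = vector.create_mask %sz : vector<[4]xi1>
// because the dropped non-scalable unit dim is known to be fully set.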
328 static FailureOr<Value>
329 createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
330  vector::CreateMaskOp op) {
331  auto type = op.getType();
332  VectorType reducedType = trimNonScalableUnitDims(type);
333  if (reducedType.getRank() == type.getRank())
334  return failure();
335 
336  SmallVector<Value> reducedOperands;
337  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
338  type.getShape(), type.getScalableDims(), op.getOperands())) {
339  if (dim == 1 && !dimIsScalable) {
340  // If the mask for the unit dim is not a constant of 1, do nothing.
341  auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
342  if (!constant || (constant.value() != 1))
343  return failure();
344  continue;
345  }
346  reducedOperands.push_back(operand);
347  }
348  return rewriter
349  .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
350  .getResult();
351 }
352 
353 namespace {
354 
355 /// Rewrites `vector.transfer_read` ops where the source has unit dims, by
356 /// inserting a memref.subview dropping those unit dims. The vector shapes are
357 /// also reduced accordingly.
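/// For illustration, a sketch of the rewrite (hypothetical shapes):
///   %v = vector.transfer_read %m[%c0, %c0, %c0, %c0], %cst
///       {in_bounds = [true, true, true, true]}
///       : memref<1x4x1x8xf32>, vector<1x4x1x8xf32>
/// becomes a rank-reduced subview plus a 2-D read and a shape cast:
///   %sv = memref.subview %m[0, 0, 0, 0] [1, 4, 1, 8] [1, 1, 1, 1]
///       : memref<1x4x1x8xf32> to memref<4x8xf32>
///   %r = vector.transfer_read %sv[%c0, %c0], %cst {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
///   %v = vector.shape_cast %r : vector<4x8xf32> to vector<1x4x1x8xf32>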
358 class TransferReadDropUnitDimsPattern
359  : public OpRewritePattern<vector::TransferReadOp> {
360  using OpRewritePattern::OpRewritePattern;
361 
362  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
363  PatternRewriter &rewriter) const override {
364  auto loc = transferReadOp.getLoc();
365  Value vector = transferReadOp.getVector();
366  VectorType vectorType = cast<VectorType>(vector.getType());
367  Value source = transferReadOp.getSource();
368  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
369  // TODO: support tensor types.
370  if (!sourceType)
371  return failure();
372  // TODO: generalize this pattern, relax the requirements here.
373  if (transferReadOp.hasOutOfBoundsDim())
374  return failure();
375  if (!transferReadOp.getPermutationMap().isMinorIdentity())
376  return failure();
377  // Check if the source shape can be further reduced.
378  int reducedRank = getReducedRank(sourceType.getShape());
379  if (reducedRank == sourceType.getRank())
380  return failure();
381  // Check if the reduced vector shape matches the reduced source shape.
382  // Otherwise, this case is not supported yet.
383  VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
384  if (reducedRank != reducedVectorType.getRank())
385  return failure();
386  if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
387  return getConstantIntValue(v) != static_cast<int64_t>(0);
388  }))
389  return failure();
390 
391  Value maskOp = transferReadOp.getMask();
392  if (maskOp) {
393  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
394  if (!createMaskOp)
395  return rewriter.notifyMatchFailure(
396  transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
397  "currently supported");
398  FailureOr<Value> rankReducedCreateMask =
399  createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
400  if (failed(rankReducedCreateMask))
401  return failure();
402  maskOp = *rankReducedCreateMask;
403  }
404 
405  Value reducedShapeSource =
406  rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
407  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
408  SmallVector<Value> zeros(reducedRank, c0);
409  auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
410  SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
411  auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
412  loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
413  transferReadOp.getPadding(), maskOp,
414  rewriter.getBoolArrayAttr(inBounds));
415  auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
416  loc, vectorType, newTransferReadOp);
417  rewriter.replaceOp(transferReadOp, shapeCast);
418 
419  return success();
420  }
421 };
422 
423 /// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
424 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
425 /// vector shapes are also reduced accordingly.
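/// For illustration (hypothetical shapes), mirroring the read pattern above:
///   %sv = memref.subview %m[0, 0, 0, 0] [1, 4, 1, 8] [1, 1, 1, 1]
///       : memref<1x4x1x8xf32> to memref<4x8xf32>
///   %w = vector.shape_cast %v : vector<1x4x1x8xf32> to vector<4x8xf32>
///   vector.transfer_write %w, %sv[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>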
426 class TransferWriteDropUnitDimsPattern
427  : public OpRewritePattern<vector::TransferWriteOp> {
428  using OpRewritePattern::OpRewritePattern;
429 
430  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
431  PatternRewriter &rewriter) const override {
432  auto loc = transferWriteOp.getLoc();
433  Value vector = transferWriteOp.getVector();
434  VectorType vectorType = cast<VectorType>(vector.getType());
435  Value source = transferWriteOp.getSource();
436  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
437  // TODO: support tensor type.
438  if (!sourceType)
439  return failure();
440  // TODO: generalize this pattern, relax the requirements here.
441  if (transferWriteOp.hasOutOfBoundsDim())
442  return failure();
443  if (!transferWriteOp.getPermutationMap().isMinorIdentity())
444  return failure();
445  // Check if the destination shape can be further reduced.
446  int reducedRank = getReducedRank(sourceType.getShape());
447  if (reducedRank == sourceType.getRank())
448  return failure();
449  // Check if the reduced vector shape matches the reduced destination shape.
450  // Otherwise, this case is not supported yet.
451  VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
452  if (reducedRank != reducedVectorType.getRank())
453  return failure();
454  if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
455  return getConstantIntValue(v) != static_cast<int64_t>(0);
456  }))
457  return failure();
458 
459  Value maskOp = transferWriteOp.getMask();
460  if (maskOp) {
461  auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
462  if (!createMaskOp)
463  return rewriter.notifyMatchFailure(
464  transferWriteOp,
465  "unsupported mask op, only 'vector.create_mask' is "
466  "currently supported");
467  FailureOr<Value> rankReducedCreateMask =
468  createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
469  if (failed(rankReducedCreateMask))
470  return failure();
471  maskOp = *rankReducedCreateMask;
472  }
473  Value reducedShapeSource =
474  rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
475  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
476  SmallVector<Value> zeros(reducedRank, c0);
477  auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
478  SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
479  auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
480  loc, reducedVectorType, vector);
481  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
482  transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
483  identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
484 
485  return success();
486  }
487 };
488 
489 } // namespace
490 
491 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
492 /// input starting at `firstDimToCollapse`.
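/// For example (hypothetical types): with `firstDimToCollapse` = 2, a
/// memref<2x3x4x5xf32> input is collapsed with reassociation [[0], [1], [2, 3]]:
///   %c = memref.collapse_shape %m [[0], [1], [2, 3]]
///       : memref<2x3x4x5xf32> into memref<2x3x20xf32>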
493 static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
494  Value input, int64_t firstDimToCollapse) {
495  ShapedType inputType = cast<ShapedType>(input.getType());
496  if (inputType.getRank() == 1)
497  return input;
498  SmallVector<ReassociationIndices> reassociation;
499  for (int64_t i = 0; i < firstDimToCollapse; ++i)
500  reassociation.push_back(ReassociationIndices{i});
501  ReassociationIndices collapsedIndices;
502  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
503  collapsedIndices.push_back(i);
504  reassociation.push_back(collapsedIndices);
505  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
506 }
507 
508 /// Returns the new indices that collapses the inner dimensions starting from
509 /// the `firstDimToCollapse` dimension.
510 static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
511  Location loc,
512  ArrayRef<int64_t> shape,
513  ValueRange indices,
514  int64_t firstDimToCollapse) {
515  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
516 
517  // If all the collapsed indices are zero then no extra logic is needed.
518  // Otherwise, a new offset/index has to be computed.
519  SmallVector<Value> indicesAfterCollapsing(
520  indices.begin(), indices.begin() + firstDimToCollapse);
521  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
522  indices.end());
523  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
524  indicesAfterCollapsing.push_back(indicesToCollapse[0]);
525  return indicesAfterCollapsing;
526  }
527 
528  // Compute the remaining trailing index/offset required for reading from
529  // the collapsed memref:
530  //
531  // offset = 0
532  // for (i = firstDimToCollapse; i < outputRank; ++i)
533  // offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
534  //
535  // For this example:
536  // %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
537  // memref<1x43x2xi32>, vector<1x2xi32>
538  // which would be collapsed to:
539  // %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
540  // memref<1x86xi32>, vector<2xi32>
541  // one would get the following offset:
542  // %offset = %arg0 * 43
543  OpFoldResult collapsedOffset =
544  rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
545 
546  auto collapsedStrides = computeSuffixProduct(
547  ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
548 
549  // Compute the collapsed offset.
550  auto &&[collapsedExpr, collapsedVals] =
551  computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
552  collapsedOffset = affine::makeComposedFoldedAffineApply(
553  rewriter, loc, collapsedExpr, collapsedVals);
554 
555  if (collapsedOffset.is<Value>()) {
556  indicesAfterCollapsing.push_back(collapsedOffset.get<Value>());
557  } else {
558  indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
559  loc, *getConstantIntValue(collapsedOffset)));
560  }
561 
562  return indicesAfterCollapsing;
563 }
564 
565 namespace {
566 
567 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
568 /// memref.collapse_shape on the source so that the resulting
569 /// vector.transfer_read has a 1D source. Requires the source shape to be
570 /// already reduced i.e. without unit dims.
571 ///
572 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
573 /// the trailing dimension of the vector read is smaller than the provided
574 /// bitwidth.
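/// For illustration, a sketch of the rewrite (hypothetical shapes, default
/// `targetVectorBitwidth`):
///   %v = vector.transfer_read %m[%c0, %c0], %cst {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
/// becomes a flattened 1-D read plus a shape cast:
///   %c = memref.collapse_shape %m [[0, 1]] : memref<4x8xf32> into memref<32xf32>
///   %r = vector.transfer_read %c[%c0], %cst {in_bounds = [true]}
///       : memref<32xf32>, vector<32xf32>
///   %v = vector.shape_cast %r : vector<32xf32> to vector<4x8xf32>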
575 class FlattenContiguousRowMajorTransferReadPattern
576  : public OpRewritePattern<vector::TransferReadOp> {
577 public:
578  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
579  unsigned vectorBitwidth,
580  PatternBenefit benefit)
581  : OpRewritePattern<vector::TransferReadOp>(context, benefit),
582  targetVectorBitwidth(vectorBitwidth) {}
583 
584  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
585  PatternRewriter &rewriter) const override {
586  auto loc = transferReadOp.getLoc();
587  Value vector = transferReadOp.getVector();
588  VectorType vectorType = cast<VectorType>(vector.getType());
589  auto source = transferReadOp.getSource();
590  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
591 
592  // 0. Check pre-conditions
593  // The contiguity check below only applies to memrefs; bail out on tensors.
594  if (!sourceType)
595  return failure();
596  // If this is already 0D/1D, there's nothing to do.
597  if (vectorType.getRank() <= 1)
598  return failure();
599  if (!vectorType.getElementType().isSignlessIntOrFloat())
600  return failure();
601  unsigned trailingVectorDimBitwidth =
602  vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
603  if (trailingVectorDimBitwidth >= targetVectorBitwidth)
604  return failure();
605  if (!vector::isContiguousSlice(sourceType, vectorType))
606  return failure();
607  // TODO: generalize this pattern, relax the requirements here.
608  if (transferReadOp.hasOutOfBoundsDim())
609  return failure();
610  if (!transferReadOp.getPermutationMap().isMinorIdentity())
611  return failure();
612  if (transferReadOp.getMask())
613  return failure();
614 
615  int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
616 
617  // 1. Collapse the source memref
618  Value collapsedSource =
619  collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
620  MemRefType collapsedSourceType =
621  cast<MemRefType>(collapsedSource.getType());
622  int64_t collapsedRank = collapsedSourceType.getRank();
623  assert(collapsedRank == firstDimToCollapse + 1);
624 
625  // 2. Generate input args for a new vector.transfer_read that will read
626  // from the collapsed memref.
627  // 2.1. New dim exprs + affine map
628  SmallVector<AffineExpr, 1> dimExprs{
629  getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
630  auto collapsedMap =
631  AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
632 
633  // 2.2 New indices
634  SmallVector<Value> collapsedIndices =
635  getCollapsedIndices(rewriter, loc, sourceType.getShape(),
636  transferReadOp.getIndices(), firstDimToCollapse);
637 
638  // 3. Create new vector.transfer_read that reads from the collapsed memref
639  VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
640  vectorType.getElementType());
641  vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
642  loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
643  flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
644 
645  // 4. Replace the old transfer_read with the new one reading from the
646  // collapsed shape
647  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
648  transferReadOp, cast<VectorType>(vector.getType()), flatRead);
649  return success();
650  }
651 
652 private:
653  // Minimum bitwidth that the trailing vector dimension should have after
654  // flattening.
655  unsigned targetVectorBitwidth;
656 };
657 
658 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
659 /// memref.collapse_shape on the source so that the resulting
660 /// vector.transfer_write has a 1D source. Requires the source shape to be
661 /// already reduced i.e. without unit dims.
662 ///
663 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
664 /// the trailing dimension of the written vector is smaller than the provided
665 /// bitwidth.
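/// For illustration (hypothetical shapes):
///   %c = memref.collapse_shape %m [[0, 1]] : memref<4x8xf32> into memref<32xf32>
///   %w = vector.shape_cast %v : vector<4x8xf32> to vector<32xf32>
///   vector.transfer_write %w, %c[%c0] {in_bounds = [true]}
///       : vector<32xf32>, memref<32xf32>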
666 class FlattenContiguousRowMajorTransferWritePattern
667  : public OpRewritePattern<vector::TransferWriteOp> {
668 public:
669  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
670  unsigned vectorBitwidth,
671  PatternBenefit benefit)
672  : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
673  targetVectorBitwidth(vectorBitwidth) {}
674 
675  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
676  PatternRewriter &rewriter) const override {
677  auto loc = transferWriteOp.getLoc();
678  Value vector = transferWriteOp.getVector();
679  VectorType vectorType = cast<VectorType>(vector.getType());
680  Value source = transferWriteOp.getSource();
681  MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
682 
683  // 0. Check pre-conditions
684  // The contiguity check below only applies to memrefs; bail out on tensors.
685  if (!sourceType)
686  return failure();
687  // If this is already 0D/1D, there's nothing to do.
688  if (vectorType.getRank() <= 1)
689  // Already 0D/1D, nothing to do.
690  return failure();
691  if (!vectorType.getElementType().isSignlessIntOrFloat())
692  return failure();
693  unsigned trailingVectorDimBitwidth =
694  vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
695  if (trailingVectorDimBitwidth >= targetVectorBitwidth)
696  return failure();
697  if (!vector::isContiguousSlice(sourceType, vectorType))
698  return failure();
699  // TODO: generalize this pattern, relax the requirements here.
700  if (transferWriteOp.hasOutOfBoundsDim())
701  return failure();
702  if (!transferWriteOp.getPermutationMap().isMinorIdentity())
703  return failure();
704  if (transferWriteOp.getMask())
705  return failure();
706 
707  int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
708 
709  // 1. Collapse the source memref
710  Value collapsedSource =
711  collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
712  MemRefType collapsedSourceType =
713  cast<MemRefType>(collapsedSource.getType());
714  int64_t collapsedRank = collapsedSourceType.getRank();
715  assert(collapsedRank == firstDimToCollapse + 1);
716 
717  // 2. Generate input args for a new vector.transfer_write that will write
718  // to the collapsed memref.
719  // 2.1. New dim exprs + affine map
720  SmallVector<AffineExpr, 1> dimExprs{
721  getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
722  auto collapsedMap =
723  AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
724 
725  // 2.2 New indices
726  SmallVector<Value> collapsedIndices =
727  getCollapsedIndices(rewriter, loc, sourceType.getShape(),
728  transferWriteOp.getIndices(), firstDimToCollapse);
729 
730  // 3. Create new vector.transfer_write that writes to the collapsed memref
731  VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
732  vectorType.getElementType());
733  Value flatVector =
734  rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
735  vector::TransferWriteOp flatWrite =
736  rewriter.create<vector::TransferWriteOp>(
737  loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
738  flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
739 
740  // 4. Replace the old transfer_write with the new one writing the
741  // collapsed shape
742  rewriter.eraseOp(transferWriteOp);
743  return success();
744  }
745 
746 private:
747  // Minimum bitwidth that the trailing vector dimension should have after
748  // flattening.
749  unsigned targetVectorBitwidth;
750 };
751 
752 /// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
753 /// to `memref.load` patterns. The `match` method is shared for both
754 /// `vector.extract` and `vector.extractelement`.
755 template <class VectorExtractOp>
756 class RewriteScalarExtractOfTransferReadBase
757  : public OpRewritePattern<VectorExtractOp> {
758  using Base = OpRewritePattern<VectorExtractOp>;
759 
760 public:
761  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
762  PatternBenefit benefit,
763  bool allowMultipleUses)
764  : Base::OpRewritePattern(context, benefit),
765  allowMultipleUses(allowMultipleUses) {}
766 
767  LogicalResult match(VectorExtractOp extractOp) const override {
768  auto xferOp =
769  extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
770  if (!xferOp)
771  return failure();
772  // Check that we are extracting a scalar and not a sub-vector.
773  if (isa<VectorType>(extractOp.getResult().getType()))
774  return failure();
775  // If multiple uses are not allowed, check if xfer has a single use.
776  if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
777  return failure();
778  // If multiple uses are allowed, check if all the xfer uses are extract ops.
779  if (allowMultipleUses &&
780  !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
781  return isa<vector::ExtractOp, vector::ExtractElementOp>(
782  use.getOwner());
783  }))
784  return failure();
785  // Mask not supported.
786  if (xferOp.getMask())
787  return failure();
788  // Map not supported.
789  if (!xferOp.getPermutationMap().isMinorIdentity())
790  return failure();
791  // Cannot rewrite if the indices may be out of bounds.
792  if (xferOp.hasOutOfBoundsDim())
793  return failure();
794  return success();
795  }
796 
797 private:
798  bool allowMultipleUses;
799 };
800 
801 /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
802 ///
803 /// All the users of the transfer op must be either `vector.extractelement` or
804 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
805 /// transfer ops with any number of users. Otherwise, rewrite only if the
806 /// extract op is the single user of the transfer op. Rewriting a single
807 /// vector load with multiple scalar loads may negatively affect performance.
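/// For illustration (hypothetical IR):
///   %v = vector.transfer_read %m[%i], %cst {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
///   %e = vector.extractelement %v[%j : index] : vector<4xf32>
/// becomes a single scalar load at the combined index:
///   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
///   %e = memref.load %m[%idx] : memref<?xf32>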
808 class RewriteScalarExtractElementOfTransferRead
809  : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
810  using RewriteScalarExtractOfTransferReadBase::
811  RewriteScalarExtractOfTransferReadBase;
812 
813  void rewrite(vector::ExtractElementOp extractOp,
814  PatternRewriter &rewriter) const override {
815  // Construct scalar load.
816  auto loc = extractOp.getLoc();
817  auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
818  SmallVector<Value> newIndices(xferOp.getIndices().begin(),
819  xferOp.getIndices().end());
820  if (extractOp.getPosition()) {
821  AffineExpr sym0, sym1;
822  bindSymbols(extractOp.getContext(), sym0, sym1);
823  OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
824  rewriter, loc, sym0 + sym1,
825  {newIndices[newIndices.size() - 1], extractOp.getPosition()});
826  if (ofr.is<Value>()) {
827  newIndices[newIndices.size() - 1] = ofr.get<Value>();
828  } else {
829  newIndices[newIndices.size() - 1] =
830  rewriter.create<arith::ConstantIndexOp>(loc,
831  *getConstantIntValue(ofr));
832  }
833  }
834  if (isa<MemRefType>(xferOp.getSource().getType())) {
835  rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
836  newIndices);
837  } else {
838  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
839  extractOp, xferOp.getSource(), newIndices);
840  }
841  }
842 };
843 
844 /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
845 /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
846 ///
847 /// All the users of the transfer op must be either `vector.extractelement` or
848 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
849 /// transfer ops with any number of users. Otherwise, rewrite only if the
850 /// extract op is the single user of the transfer op. Rewriting a single
851 /// vector load with multiple scalar loads may negatively affect performance.
852 class RewriteScalarExtractOfTransferRead
853  : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
854  using RewriteScalarExtractOfTransferReadBase::
855  RewriteScalarExtractOfTransferReadBase;
856 
857  void rewrite(vector::ExtractOp extractOp,
858  PatternRewriter &rewriter) const override {
859  // Construct scalar load.
860  auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
861  SmallVector<Value> newIndices(xferOp.getIndices().begin(),
862  xferOp.getIndices().end());
863  for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
864  assert(pos.is<Attribute>() && "Unexpected non-constant index");
865  int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
866  int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
867  OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
868  rewriter, extractOp.getLoc(),
869  rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
870  if (ofr.is<Value>()) {
871  newIndices[idx] = ofr.get<Value>();
872  } else {
873  newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
874  extractOp.getLoc(), *getConstantIntValue(ofr));
875  }
876  }
877  if (isa<MemRefType>(xferOp.getSource().getType())) {
878  rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
879  newIndices);
880  } else {
881  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
882  extractOp, xferOp.getSource(), newIndices);
883  }
884  }
885 };
886 
887 /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
888 /// to memref.store.
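/// For illustration (hypothetical IR):
///   vector.transfer_write %v, %m[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<?x?xf32>
/// becomes
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %m[%i, %j] : memref<?x?xf32>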
889 class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
890  using OpRewritePattern::OpRewritePattern;
891 
892  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
893  PatternRewriter &rewriter) const override {
894  // Must be a scalar write.
895  auto vecType = xferOp.getVectorType();
896  if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
897  return failure();
898  // Mask not supported.
899  if (xferOp.getMask())
900  return failure();
901  // Map not supported.
902  if (!xferOp.getPermutationMap().isMinorIdentity())
903  return failure();
904  // Only float and integer element types are supported.
905  Value scalar;
906  if (vecType.getRank() == 0) {
907  // vector.extract does not support vector<f32> etc., so use
908  // vector.extractelement instead.
909  scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
910  xferOp.getVector());
911  } else {
912  SmallVector<int64_t> pos(vecType.getRank(), 0);
913  scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
914  xferOp.getVector(), pos);
915  }
916  // Construct a scalar store.
917  if (isa<MemRefType>(xferOp.getSource().getType())) {
918  rewriter.replaceOpWithNewOp<memref::StoreOp>(
919  xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
920  } else {
921  rewriter.replaceOpWithNewOp<tensor::InsertOp>(
922  xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
923  }
924  return success();
925  }
926 };
927 
928 } // namespace
929 
930 void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
931  Operation *rootOp) {
932  TransferOptimization opt(rewriter, rootOp);
933  // Run store to load forwarding first since it can expose more dead store
934  // opportunities.
935  rootOp->walk([&](vector::TransferReadOp read) {
936  if (isa<MemRefType>(read.getShapedType()))
937  opt.storeToLoadForwarding(read);
938  });
939  opt.removeDeadOp();
940  rootOp->walk([&](vector::TransferWriteOp write) {
941  if (isa<MemRefType>(write.getShapedType()))
942  opt.deadStoreOp(write);
943  });
944  opt.removeDeadOp();
945 }
946 
947 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
948  RewritePatternSet &patterns, PatternBenefit benefit,
949  bool allowMultipleUses) {
950  patterns.add<RewriteScalarExtractElementOfTransferRead,
951  RewriteScalarExtractOfTransferRead>(patterns.getContext(),
952  benefit, allowMultipleUses);
953  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
954 }
955 
956 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
957  RewritePatternSet &patterns, PatternBenefit benefit) {
958  patterns
959  .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
960  patterns.getContext(), benefit);
961  populateShapeCastFoldingPatterns(patterns, benefit);
962 }
963 
964 void mlir::vector::populateFlattenVectorTransferPatterns(
965  RewritePatternSet &patterns, unsigned targetVectorBitwidth,
966  PatternBenefit benefit) {
967  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
968  FlattenContiguousRowMajorTransferWritePattern>(
969  patterns.getContext(), targetVectorBitwidth, benefit);
970  populateShapeCastFoldingPatterns(patterns, benefit);
971  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
972 }
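// A minimal usage sketch (assumptions: a `func::FuncOp funcOp`, its
// `MLIRContext *ctx`, and the greedy driver from
// "mlir/Transforms/GreedyPatternRewriteDriver.h"; this file does not mandate
// any particular driver):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     funcOp.emitWarning("pattern application did not converge");
//
//   // The dataflow rewrites run directly on a rewriter rather than as
//   // patterns:
//   IRRewriter rewriter(ctx);
//   vector::transferOpflowOpt(rewriter, funcOp);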