//===- VectorTransferOpTransforms.cpp - transfer op transforms -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace
/// Return true if there is a path from start operation to dest operation,
/// otherwise return false. The operations have to be in the same region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  llvm::SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they are all post-dominating the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
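///
/// A minimal illustration (hypothetical IR and value names, not taken from a
/// test):
///
///   vector.transfer_write %v0, %buf[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   // ... no read of %buf reachable from here ...
///   vector.transfer_write %v1, %buf[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///
/// The first write is dead: the second write accesses the same memref with the
/// same indices and vector type, post-dominates it, and no read of %buf is
/// reachable in between, so the first write can be erased.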
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check if the candidate can overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they are all dominating the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
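///
/// A minimal illustration (hypothetical IR and value names, not taken from a
/// test):
///
///   vector.transfer_write %v, %buf[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %pad
///       : memref<4x4xf32>, vector<4xf32>
///
/// With no aliasing write in between, %r can be replaced by %v and the
/// transfer_read erased.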
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
      users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Converts OpFoldResults to int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
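///
/// For example (illustrative types only, assuming an identity-layout source),
/// a memref<1x16x8xf32> input would be rewritten to something like:
///
///   %sv = memref.subview %src[0, 0, 0] [1, 16, 8] [1, 1, 1]
///       : memref<1x16x8xf32> to memref<16x8xf32>
///
/// while a memref<16x8xf32> input is returned unchanged.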
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dimensions from `oldType` and returns the result
/// type.
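/// For example (illustrative only), vector<1x[1]x4x1xf32> becomes
/// vector<[1]x4xf32>: the static unit dims are dropped, but the scalable unit
/// dim `[1]` is kept.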
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
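// For example (illustrative only, with hypothetical value names):
//
//   %mask = vector.create_mask %c1, %dim : vector<1x[4]xi1>
//
// becomes
//
//   %mask = vector.create_mask %dim : vector<[4]xi1>
//
// The operand of the dropped unit dim must be a constant 1, otherwise the
// rewrite bails out.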
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
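///
/// A sketch of the rewrite (illustrative types and value names, not taken from
/// a test):
///
///   %r = vector.transfer_read %src[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true]}
///       : memref<1x16x8xf32>, vector<1x16x8xf32>
///
/// becomes, roughly:
///
///   %sv = memref.subview %src[0, 0, 0] [1, 16, 8] [1, 1, 1]
///       : memref<1x16x8xf32> to memref<16x8xf32>
///   %v  = vector.transfer_read %sv[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<16x8xf32>, vector<16x8xf32>
///   %r  = vector.shape_cast %v : vector<16x8xf32> to vector<1x16x8xf32>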
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
    rewriter.replaceOp(transferReadOp, shapeCast);

    return success();
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
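///
/// A sketch of the rewrite (illustrative types and value names, not taken from
/// a test):
///
///   vector.transfer_write %v, %dst[%c0, %c0, %c0]
///       {in_bounds = [true, true, true]}
///       : vector<1x16x8xf32>, memref<1x16x8xf32>
///
/// becomes, roughly:
///
///   %sv = memref.subview %dst[0, 0, 0] [1, 16, 8] [1, 1, 1]
///       : memref<1x16x8xf32> to memref<16x8xf32>
///   %vc = vector.shape_cast %v : vector<1x16x8xf32> to vector<16x8xf32>
///   vector.transfer_write %vc, %sv[%c0, %c0] {in_bounds = [true, true]}
///       : vector<16x8xf32>, memref<16x8xf32>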
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));

    return success();
  }
};

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
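///
/// For example (illustrative only), with `firstDimToCollapse = 1`:
///
///   %c = memref.collapse_shape %src [[0], [1, 2]]
///       : memref<2x16x8xf32> into memref<2x128xf32>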
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
/// TODO: Extract the logic that writes to outIndices so that this method
/// simply checks one pre-condition.
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
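///
/// A sketch of the rewrite (illustrative types and value names, not taken from
/// a test):
///
///   %r = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
///
/// becomes, roughly:
///
///   %c = memref.collapse_shape %src [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %f = vector.transfer_read %c[%c0], %pad {in_bounds = [true]}
///       : memref<32xf32>, vector<32xf32>
///   %r = vector.shape_cast %f : vector<32xf32> to vector<4x8xf32>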
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only supports memrefs; bail out on tensor
    // sources.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        dyn_cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    // If all the collapsed indices are zero then no extra logic is needed.
    // Otherwise, a new offset/index has to be computed.
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstDimToCollapse,
                                                collapsedIndices))) {
      // Copy all the leading indices.
      SmallVector<Value> indices = transferReadOp.getIndices();
      collapsedIndices.append(indices.begin(),
                              indices.begin() + firstDimToCollapse);

      // Compute the remaining trailing index/offset required for reading from
      // the collapsed memref:
      //
      //    offset = 0
      //    for (i = firstDimToCollapse; i < outputRank; ++i)
      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
      //
      // For this example:
      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
      //     memref<1x43x2xi32>, vector<1x2xi32>
      // which would be collapsed to:
      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
      //     memref<1x86xi32>, vector<2xi32>
      // one would get the following offset:
      //    %offset = %arg0 * 43
      OpFoldResult collapsedOffset =
          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

      auto sourceShape = sourceType.getShape();
      auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
          sourceShape.begin() + firstDimToCollapse, sourceShape.end()));

      // Compute the collapsed offset.
      ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                        indices.end());
      auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
          collapsedOffset, collapsedStrides, indicesToCollapse);
      collapsedOffset = affine::makeComposedFoldedAffineApply(
          rewriter, loc, collapsedExpr, collapsedVals);

      if (collapsedOffset.is<Value>()) {
        collapsedIndices.push_back(collapsedOffset.get<Value>());
      } else {
        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
            loc, *getConstantIntValue(collapsedOffset)));
      }
    }

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
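///
/// A sketch of the rewrite (illustrative types and value names, not taken from
/// a test):
///
///   vector.transfer_write %v, %dst[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///
/// becomes, roughly:
///
///   %c  = memref.collapse_shape %dst [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %vf = vector.shape_cast %v : vector<4x8xf32> to vector<32xf32>
///   vector.transfer_write %vf, %c[%c0] {in_bounds = [true]}
///       : vector<32xf32>, memref<32xf32>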
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below only supports memrefs; bail out on tensor
    // destinations.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();

    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
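///
/// A sketch of the rewrite (illustrative types and value names, not taken from
/// a test):
///
///   %r = vector.transfer_read %src[%i], %pad {in_bounds = [true]}
///       : memref<16xf32>, vector<4xf32>
///   %e = vector.extractelement %r[%j : index] : vector<4xf32>
///
/// becomes, roughly:
///
///   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
///   %e   = memref.load %src[%idx] : memref<16xf32>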
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
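///
/// A sketch of the rewrite (illustrative types and value names, not taken from
/// a test):
///
///   %r = vector.transfer_read %src[%i, %c0], %pad {in_bounds = [true, true]}
///       : memref<8x4xf32>, vector<2x4xf32>
///   %e = vector.extract %r[1, 2] : f32 from vector<2x4xf32>
///
/// becomes, roughly:
///
///   %row = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%i]
///   %c2  = arith.constant 2 : index
///   %e   = memref.load %src[%row, %c2] : memref<8x4xf32>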
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(pos.is<Attribute>() && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
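///
/// A sketch of the rewrite (illustrative types and value names, not taken from
/// a test):
///
///   vector.transfer_write %v, %dst[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<8x8xf32>
///
/// becomes, roughly:
///
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %dst[%i, %j] : memref<8x8xf32>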
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}