MLIR  19.0.0git
VectorTransferOpTransforms.cpp
//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace
/// Return true if there is a path from start operation to dest operation,
/// otherwise return false. The operations have to be in the same region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  llvm::SmallDenseSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For one transfer_write to fully overwrite another, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
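///
/// As an illustrative sketch (hypothetical IR, in-bounds accesses, default
/// minor-identity maps):
///
///   vector.transfer_write %v0, %buf[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   ...   // no read of %buf on any path between the two writes
///   vector.transfer_write %v1, %buf[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///
/// Here the first transfer_write is dead: the second write post-dominates it,
/// accesses the same elements, and no reachable read observes the old value.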
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source =
      memref::skipSubViewsAndCasts(cast<MemrefValue>(write.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (isa<memref::SubViewOp, memref::CastOp>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check for a candidate that can overwrite the store.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// To be a candidate for store-to-load forwarding, a transfer_write must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
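///
/// As an illustrative sketch (hypothetical IR, in-bounds accesses, default
/// minor-identity maps):
///
///   vector.transfer_write %v, %buf[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<4x4xf32>, vector<4xf32>
///
/// With no aliasing write reaching the read in between, %r can simply be
/// replaced by %v and the transfer_read erased.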
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source =
      memref::skipSubViewsAndCasts(cast<MemrefValue>(read.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(read.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Converts OpFoldResults to int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input. Or just returns the input if it was already without unit dims.
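///
/// As an illustrative sketch (hypothetical types), a memref<1x8x1x4xf32> input
/// would roughly become:
///
///   memref.subview %input[0, 0, 0, 0] [1, 8, 1, 4] [1, 1, 1, 1]
///       : memref<1x8x1x4xf32> to memref<8x4xf32>
///
/// while an input that already has no unit dims is returned unchanged.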
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable one dimensions from `oldType` and returns the result
/// type.
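/// For example, vector<1x[4]x1xf32> is trimmed to vector<[4]xf32>, while the
/// scalable unit dim in vector<4x[1]xf32> is preserved.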
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
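// As an illustrative sketch (hypothetical IR), assuming the operands for the
// dropped unit dims are constant 1:
//
//   %m = vector.create_mask %c1, %dim, %c1 : vector<1x[4]x1xi1>
//
// would roughly be rewritten to:
//
//   %m = vector.create_mask %dim : vector<[4]xi1>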
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
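///
/// As an illustrative sketch (hypothetical IR, zero indices, in-bounds):
///
///   %v = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<1x4xf32>, vector<1x4xf32>
///
/// becomes, roughly:
///
///   %sv = memref.subview %m[0, 0] [1, 4] [1, 1]
///       : memref<1x4xf32> to memref<4xf32>
///   %r  = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
///       : memref<4xf32>, vector<4xf32>
///   %v  = vector.shape_cast %r : vector<4xf32> to vector<1x4xf32>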
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
    rewriter.replaceOp(transferReadOp, shapeCast);

    return success();
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
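///
/// As an illustrative sketch (hypothetical IR, zero indices, in-bounds):
///
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<1x4xf32>, memref<1x4xf32>
///
/// becomes, roughly, a vector.shape_cast of %v to vector<4xf32> written into a
/// rank-reduced memref<4xf32> subview of %m.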
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));

    return success();
  }
};

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
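///
/// As an illustrative sketch (hypothetical types), with firstDimToCollapse = 1:
///
///   memref.collapse_shape %input [[0], [1, 2, 3]]
///       : memref<2x3x4x5xf32> into memref<2x60xf32>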
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
/// TODO: Extract the logic that writes to outIndices so that this method
/// simply checks one pre-condition.
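///
/// For example (hypothetical values), with firstDimToCollapse = 1, indices
/// (%i, %c0, %c0) are truncated to (%i, %c0), whereas (%i, %j, %c0) fails the
/// check because %j is not a constant 0.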
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
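///
/// As an illustrative sketch (hypothetical IR, zero indices, in-bounds,
/// minor-identity map):
///
///   %v = vector.transfer_read %m[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true]} : memref<8x2x4xf32>, vector<2x4xf32>
///
/// becomes, roughly:
///
///   %cs = memref.collapse_shape %m [[0], [1, 2]]
///       : memref<8x2x4xf32> into memref<8x8xf32>
///   %r  = vector.transfer_read %cs[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<8x8xf32>, vector<8xf32>
///   %v  = vector.shape_cast %r : vector<8xf32> to vector<2x4xf32>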
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below is only valid for memref sources, so bail out
    // on tensors.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        dyn_cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    // If all the collapsed indices are zero then no extra logic is needed.
    // Otherwise, a new offset/index has to be computed.
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstDimToCollapse,
                                                collapsedIndices))) {
      // Copy all the leading indices.
      SmallVector<Value> indices = transferReadOp.getIndices();
      collapsedIndices.append(indices.begin(),
                              indices.begin() + firstDimToCollapse);

      // Compute the remaining trailing index/offset required for reading from
      // the collapsed memref:
      //
      //    offset = 0
      //    for (i = firstDimToCollapse; i < outputRank; ++i)
      //      offset += collapsedStrides[i] * transferReadOp.indices[i]
      //
      // where collapsedStrides is the suffix product of the dim sizes being
      // collapsed. For this example:
      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
      //     memref<1x43x2xi32>, vector<1x2xi32>
      // which would be collapsed to:
      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
      //     memref<1x86xi32>, vector<2xi32>
      // one would get the following offset:
      //    %offset = %arg0 * 2
      OpFoldResult collapsedOffset =
          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

      auto sourceShape = sourceType.getShape();
      auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
          sourceShape.begin() + firstDimToCollapse, sourceShape.end()));

      // Compute the collapsed offset.
      ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                        indices.end());
      auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
          collapsedOffset, collapsedStrides, indicesToCollapse);
      collapsedOffset = affine::makeComposedFoldedAffineApply(
          rewriter, loc, collapsedExpr, collapsedVals);

      if (collapsedOffset.is<Value>()) {
        collapsedIndices.push_back(collapsedOffset.get<Value>());
      } else {
        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
            loc, *getConstantIntValue(collapsedOffset)));
      }
    }

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
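///
/// As an illustrative sketch (hypothetical IR, zero indices, in-bounds,
/// minor-identity map):
///
///   vector.transfer_write %v, %m[%c0, %c0, %c0] {in_bounds = [true, true]}
///       : vector<2x4xf32>, memref<8x2x4xf32>
///
/// becomes, roughly, a vector.shape_cast of %v to vector<8xf32> written into
///
///   memref.collapse_shape %m [[0], [1, 2]]
///       : memref<8x2x4xf32> into memref<8x8xf32>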
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below is only valid for memref sources, so bail out
    // on tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();

    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Base class for `vector.extract`/`vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
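///
/// As an illustrative sketch (hypothetical IR, in-bounds, 1-D case):
///
///   %v = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
///   %e = vector.extractelement %v[%j : index] : vector<4xf32>
///
/// becomes, roughly:
///
///   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
///   %e = memref.load %m[%idx] : memref<?xf32>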
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
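///
/// As an illustrative sketch (hypothetical IR, in-bounds):
///
///   %v = vector.transfer_read %m[%i, %c0], %pad {in_bounds = [true, true]}
///       : memref<?x8xf32>, vector<2x4xf32>
///   %e = vector.extract %v[1, 2] : f32 from vector<2x4xf32>
///
/// becomes, roughly:
///
///   %r = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%i]
///   %e = memref.load %m[%r, %c2] : memref<?x8xf32>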
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(pos.is<Attribute>() && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
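///
/// As an illustrative sketch (hypothetical IR):
///
///   vector.transfer_write %v, %m[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<?x?xf32>
///
/// becomes, roughly:
///
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %m[%i, %j] : memref<?x?xf32>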
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}