MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::memref;
29 
30 /// Materialize a single constant operation from a given attribute value with
31 /// the desired resultant type.
32 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
33  Attribute value, Type type,
34  Location loc) {
35  return arith::ConstantOp::materialize(builder, value, type, loc);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Common canonicalization pattern support logic
40 //===----------------------------------------------------------------------===//
41 
42 /// This is a common class used for patterns of the form
43 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
44 /// into the root operation directly.
45 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
46  bool folded = false;
47  for (OpOperand &operand : op->getOpOperands()) {
48  auto cast = operand.get().getDefiningOp<CastOp>();
49  if (cast && operand.get() != inner &&
50  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
51  operand.set(cast.getOperand());
52  folded = true;
53  }
54  }
55  return success(folded);
56 }
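// For illustration, a minimal example of the rewrite this helper enables; the
// consumer shown here is memref.dealloc, whose fold below uses it:
//
//   %0 = memref.cast %arg : memref<8x16xf32> to memref<?x?xf32>
//   memref.dealloc %0 : memref<?x?xf32>
//
// folds to:
//
//   memref.dealloc %arg : memref<8x16xf32>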
57 
58 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
59 /// type.
60 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
61  if (auto memref = llvm::dyn_cast<MemRefType>(type))
62  return RankedTensorType::get(memref.getShape(), memref.getElementType());
63  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
64  return UnrankedTensorType::get(memref.getElementType());
65  return NoneType::get(type.getContext());
66 }
67 
68 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
69  int64_t dim) {
70  auto memrefType = llvm::cast<MemRefType>(value.getType());
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78  Location loc, Value value) {
79  auto memrefType = llvm::cast<MemRefType>(value.getType());
80  SmallVector<OpFoldResult> result;
81  for (int64_t i = 0; i < memrefType.getRank(); ++i)
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result;
84 }
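// For illustration, a sketch of how a caller might use these helpers; the
// builder `b`, location `loc` and value `%val` are hypothetical:
//
//   // %val : memref<4x?xf32>
//   SmallVector<OpFoldResult> sizes = memref::getMixedSizes(b, loc, val);
//   // sizes[0] is the index attribute 4; sizes[1] is the SSA result of a
//   // materialized `memref.dim %val, %c1`.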
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that sets values[i] to constValues[i] if the latter is a
91 /// static value (i.e., not equal to ShapedType::kDynamic).
92 ///
93 /// If constValues[i] is dynamic, tries to extract a constant value from
94 /// value[i] to allow for additional folding opportunities. Also converts all
95 /// existing attributes to index attributes. (They may be i64 attributes.)
96 static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
97  ArrayRef<int64_t> constValues) {
98  assert(constValues.size() == values.size() &&
99  "incorrect number of const values");
100  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
101  Builder builder(values[i].getContext());
102  if (ShapedType::isStatic(cstVal)) {
103  // Constant value is known, use it directly.
104  values[i] = builder.getIndexAttr(cstVal);
105  continue;
106  }
107  if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
108  // Try to extract a constant or convert an existing to index.
109  values[i] = builder.getIndexAttr(*cst);
110  }
111  }
112 }
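// For illustration, a hypothetical invocation: with values = {%c8, %n} (where
// %c8 is an arith.constant 8) and constValues = {ShapedType::kDynamic, 16},
// the vector is rewritten in place to {IndexAttr(8), IndexAttr(16)}: the
// static 16 comes from constValues, and the 8 is extracted from the SSA
// constant to expose further folding.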
113 
114 /// Helper function to retrieve a lossless memory-space cast, and the
115 /// corresponding new result memref type.
116 static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
117 getMemorySpaceCastInfo(PtrLikeTypeInterface resultTy, Value src) {
118  MemorySpaceCastOpInterface castOp =
119  MemorySpaceCastOpInterface::getIfPromotableCast(src);
120 
121  // Bail if the cast is not lossless.
122  if (!castOp)
123  return {};
124 
125  // Transform the source and target type of `castOp` to have the same metadata
126  // as `resultTy`. Bail if not possible.
127  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
128  castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
129  if (failed(srcTy))
130  return {};
131 
132  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
133  castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
134  if (failed(tgtTy))
135  return {};
136 
137  // Check if this is a valid memory-space cast.
138  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
139  return {};
140 
141  return std::make_tuple(castOp, *tgtTy, *srcTy);
142 }
143 
144 /// Implementation of `bubbleDownCasts` method for memref operations that
145 /// return a single memref result.
146 template <typename ConcreteOpTy>
147 static FailureOr<std::optional<SmallVector<Value>>>
148 bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
149  OpOperand &src) {
150  auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
151  // Bail if we cannot cast.
152  if (!castOp)
153  return failure();
154 
155  // Create the new operands.
156  SmallVector<Value> operands;
157  llvm::append_range(operands, op->getOperands());
158  operands[src.getOperandNumber()] = castOp.getSourcePtr();
159 
160  // Create the new op and results.
161  auto newOp = ConcreteOpTy::create(
162  builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
163  llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
164 
165  // Insert a memory-space cast to the original memory space of the op.
166  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
167  builder, tgtTy,
168  cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
169  return std::optional<SmallVector<Value>>(
170  SmallVector<Value>({result.getTargetPtr()}));
171 }
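// For illustration, a possible rewrite enabled through this helper (assuming
// the memory-space cast is promotable; the ops and types below are only an
// example):
//
//   %m = memref.memory_space_cast %src : memref<4xf32, 1> to memref<4xf32>
//   %a = memref.assume_alignment %m, 16 : memref<4xf32>
//
// may become:
//
//   %t = memref.assume_alignment %src, 16 : memref<4xf32, 1>
//   %a = memref.memory_space_cast %t : memref<4xf32, 1> to memref<4xf32>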
172 
173 //===----------------------------------------------------------------------===//
174 // AllocOp / AllocaOp
175 //===----------------------------------------------------------------------===//
176 
177 void AllocOp::getAsmResultNames(
178  function_ref<void(Value, StringRef)> setNameFn) {
179  setNameFn(getResult(), "alloc");
180 }
181 
182 void AllocaOp::getAsmResultNames(
183  function_ref<void(Value, StringRef)> setNameFn) {
184  setNameFn(getResult(), "alloca");
185 }
186 
187 template <typename AllocLikeOp>
188 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
189  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
190  "applies to only alloc or alloca");
191  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
192  if (!memRefType)
193  return op.emitOpError("result must be a memref");
194 
195  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
196  return op.emitOpError("dimension operand count does not equal memref "
197  "dynamic dimension count");
198 
199  unsigned numSymbols = 0;
200  if (!memRefType.getLayout().isIdentity())
201  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
202  if (op.getSymbolOperands().size() != numSymbols)
203  return op.emitOpError("symbol operand count does not equal memref symbol "
204  "count: expected ")
205  << numSymbols << ", got " << op.getSymbolOperands().size();
206 
207  return success();
208 }
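// For illustration, an alloc that satisfies both operand-count checks above:
// one dynamic-size operand for the single '?' dimension and one symbol operand
// for the layout symbol s0 (the operand names are hypothetical):
//
//   %0 = memref.alloc(%d)[%s]
//       : memref<?x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 16 + d1 + s0)>>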
209 
210 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
211 
212 LogicalResult AllocaOp::verify() {
213  // An alloca op needs to have an ancestor with an allocation scope trait.
214  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
215  return emitOpError(
216  "requires an ancestor op with AutomaticAllocationScope trait");
217 
218  return verifyAllocLikeOp(*this);
219 }
220 
221 namespace {
222 /// Fold constant dimensions into an alloc like operation.
223 template <typename AllocLikeOp>
224 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
225  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
226 
227  LogicalResult matchAndRewrite(AllocLikeOp alloc,
228  PatternRewriter &rewriter) const override {
229  // Check to see if any dimension operands are constants. If so, we can
230  // substitute and drop them.
231  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
232  APInt constSizeArg;
233  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
234  return false;
235  return constSizeArg.isNonNegative();
236  }))
237  return failure();
238 
239  auto memrefType = alloc.getType();
240 
241  // Ok, we have one or more constant operands. Collect the non-constant ones
242  // and keep track of the resultant memref type to build.
243  SmallVector<int64_t, 4> newShapeConstants;
244  newShapeConstants.reserve(memrefType.getRank());
245  SmallVector<Value, 4> dynamicSizes;
246 
247  unsigned dynamicDimPos = 0;
248  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
249  int64_t dimSize = memrefType.getDimSize(dim);
250  // If this is already a static dimension, keep it.
251  if (ShapedType::isStatic(dimSize)) {
252  newShapeConstants.push_back(dimSize);
253  continue;
254  }
255  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
256  APInt constSizeArg;
257  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
258  constSizeArg.isNonNegative()) {
259  // Dynamic shape dimension will be folded.
260  newShapeConstants.push_back(constSizeArg.getZExtValue());
261  } else {
262  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
263  newShapeConstants.push_back(ShapedType::kDynamic);
264  dynamicSizes.push_back(dynamicSize);
265  }
266  dynamicDimPos++;
267  }
268 
269  // Create new memref type (which will have fewer dynamic dimensions).
270  MemRefType newMemRefType =
271  MemRefType::Builder(memrefType).setShape(newShapeConstants);
272  assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
273 
274  // Create and insert the alloc op for the new memref.
275  auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
276  dynamicSizes, alloc.getSymbolOperands(),
277  alloc.getAlignmentAttr());
278  // Insert a cast so we have the same type as the old alloc.
279  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
280  return success();
281  }
282 };
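// For illustration, the rewrite performed by SimplifyAllocConst on a
// hypothetical alloc with one constant and one truly dynamic size:
//
//   %c8 = arith.constant 8 : index
//   %a  = memref.alloc(%c8, %n) : memref<?x?xf32>
//
// becomes:
//
//   %a0 = memref.alloc(%n) : memref<8x?xf32>
//   %a  = memref.cast %a0 : memref<8x?xf32> to memref<?x?xf32>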
283 
284 /// Fold alloc operations with no users or only store and dealloc uses.
285 template <typename T>
286 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
287  using OpRewritePattern<T>::OpRewritePattern;
288 
289  LogicalResult matchAndRewrite(T alloc,
290  PatternRewriter &rewriter) const override {
291  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
292  if (auto storeOp = dyn_cast<StoreOp>(op))
293  return storeOp.getValue() == alloc;
294  return !isa<DeallocOp>(op);
295  }))
296  return failure();
297 
298  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
299  rewriter.eraseOp(user);
300 
301  rewriter.eraseOp(alloc);
302  return success();
303  }
304 };
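// For illustration, a dead allocation removed by SimplifyDeadAlloc: the alloc
// below is only written to and deallocated, so all three ops are erased.
//
//   %a = memref.alloc() : memref<4xf32>
//   memref.store %v, %a[%i] : memref<4xf32>
//   memref.dealloc %a : memref<4xf32>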
305 } // namespace
306 
307 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
308  MLIRContext *context) {
309  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
310 }
311 
312 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
313  MLIRContext *context) {
314  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
315  context);
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // ReallocOp
320 //===----------------------------------------------------------------------===//
321 
322 LogicalResult ReallocOp::verify() {
323  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
324  MemRefType resultType = getType();
325 
326  // The source memref should have identity layout (or none).
327  if (!sourceType.getLayout().isIdentity())
328  return emitError("unsupported layout for source memref type ")
329  << sourceType;
330 
331  // The result memref should have identity layout (or none).
332  if (!resultType.getLayout().isIdentity())
333  return emitError("unsupported layout for result memref type ")
334  << resultType;
335 
336  // The source memref and the result memref should be in the same memory space.
337  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
338  return emitError("different memory spaces specified for source memref "
339  "type ")
340  << sourceType << " and result memref type " << resultType;
341 
342  // The source memref and the result memref should have the same element type.
343  if (sourceType.getElementType() != resultType.getElementType())
344  return emitError("different element types specified for source memref "
345  "type ")
346  << sourceType << " and result memref type " << resultType;
347 
348  // Verify that we have the dynamic dimension operand when it is needed.
349  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
350  return emitError("missing dimension operand for result type ")
351  << resultType;
352  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
353  return emitError("unnecessary dimension operand for result type ")
354  << resultType;
355 
356  return success();
357 }
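// For illustration, forms accepted by this verifier (operand names are
// hypothetical); the second form requires the dynamic result-size operand:
//
//   %0 = memref.realloc %src : memref<64xf32> to memref<124xf32>
//   %1 = memref.realloc %src(%d) : memref<64xf32> to memref<?xf32>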
358 
359 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
360  MLIRContext *context) {
361  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // AllocaScopeOp
366 //===----------------------------------------------------------------------===//
367 
368 void AllocaScopeOp::print(OpAsmPrinter &p) {
369  bool printBlockTerminators = false;
370 
371  p << ' ';
372  if (!getResults().empty()) {
373  p << " -> (" << getResultTypes() << ")";
374  printBlockTerminators = true;
375  }
376  p << ' ';
377  p.printRegion(getBodyRegion(),
378  /*printEntryBlockArgs=*/false,
379  /*printBlockTerminators=*/printBlockTerminators);
380  p.printOptionalAttrDict((*this)->getAttrs());
381 }
382 
383 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
384  // Create a region for the body.
385  result.regions.reserve(1);
386  Region *bodyRegion = result.addRegion();
387 
388  // Parse optional results type list.
389  if (parser.parseOptionalArrowTypeList(result.types))
390  return failure();
391 
392  // Parse the body region.
393  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
394  return failure();
395  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
396  result.location);
397 
398  // Parse the optional attribute list.
399  if (parser.parseOptionalAttrDict(result.attributes))
400  return failure();
401 
402  return success();
403 }
404 
405 void AllocaScopeOp::getSuccessorRegions(
406  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
407  if (!point.isParent()) {
408  regions.push_back(RegionSuccessor(getOperation(), getResults()));
409  return;
410  }
411 
412  regions.push_back(RegionSuccessor(&getBodyRegion()));
413 }
414 
415 /// Given an operation, return whether this op is guaranteed to
416 /// allocate an AutomaticAllocationScopeResource
417 static bool isGuaranteedAutomaticAllocation(Operation *op) {
418  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
419  if (!interface)
420  return false;
421  for (auto res : op->getResults()) {
422  if (auto effect =
423  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
424  if (isa<SideEffects::AutomaticAllocationScopeResource>(
425  effect->getResource()))
426  return true;
427  }
428  }
429  return false;
430 }
431 
432 /// Given an operation, return whether this op itself could
433 /// allocate an AutomaticAllocationScopeResource. Note that
434 /// this will not check whether an operation contained within
435 /// the op can allocate.
436 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
437  // This op itself doesn't create a stack allocation,
438  // the inner allocation should be handled separately.
439  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
440  return false;
441  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
442  if (!interface)
443  return true;
444  for (auto res : op->getResults()) {
445  if (auto effect =
446  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
447  if (isa<SideEffects::AutomaticAllocationScopeResource>(
448  effect->getResource()))
449  return true;
450  }
451  }
452  return false;
453 }
454 
455 /// Return whether this op is the last non-terminating op
456 /// in a region. That is to say, it is in a one-block region
457 /// and is only followed by a terminator. This prevents
458 /// extending the lifetime of allocations.
459 static bool lastNonTerminatorInRegion(Operation *op) {
460  return op->getBlock()->mightHaveTerminator() &&
461  op->getNextNode() == op->getBlock()->getTerminator() &&
462  op->getParentRegion()->hasOneBlock();
463 }
464 
465 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
466 /// or it contains no allocation.
467 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
468  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
469 
470  LogicalResult matchAndRewrite(AllocaScopeOp op,
471  PatternRewriter &rewriter) const override {
472  bool hasPotentialAlloca =
473  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
474  if (alloc == op)
475  return WalkResult::advance();
476  if (isOpItselfPotentialAutomaticAllocation(alloc))
477  return WalkResult::interrupt();
478  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
479  return WalkResult::skip();
480  return WalkResult::advance();
481  }).wasInterrupted();
482 
483  // If this contains no potential allocation, it is always legal to
484  // inline. Otherwise, consider two conditions:
485  if (hasPotentialAlloca) {
486  // If the parent isn't an allocation scope, or we are not the last
487  // non-terminator op in the parent, we will extend the lifetime.
488  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
489  return failure();
490  if (!lastNonTerminatorInRegion(op))
491  return failure();
492  }
493 
494  Block *block = &op.getRegion().front();
495  Operation *terminator = block->getTerminator();
496  ValueRange results = terminator->getOperands();
497  rewriter.inlineBlockBefore(block, op);
498  rewriter.replaceOp(op, results);
499  rewriter.eraseOp(terminator);
500  return success();
501  }
502 };
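// For illustration, a scope that AllocaScopeInliner may remove: the parent
// func.func is itself an AutomaticAllocationScope and the alloca_scope is the
// last non-terminator op, so its body (including the alloca) is inlined.
// "test.use" is a placeholder op, not a real dialect op.
//
//   func.func @f() {
//     memref.alloca_scope {
//       %a = memref.alloca() : memref<16xf32>
//       "test.use"(%a) : (memref<16xf32>) -> ()
//     }
//     return
//   }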
503 
504 /// Move allocations into an allocation scope, if it is legal to
505 /// move them (e.g. their operands are available at the location
506 /// the op would be moved to).
507 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
508  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
509 
510  LogicalResult matchAndRewrite(AllocaScopeOp op,
511  PatternRewriter &rewriter) const override {
512 
513  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
514  return failure();
515 
516  Operation *lastParentWithoutScope = op->getParentOp();
517 
518  if (!lastParentWithoutScope ||
519  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
520  return failure();
521 
522  // Only apply if this is the last non-terminator op in the block of a
523  // one-block region; otherwise hoisting would extend the lifetime of the
524  // allocations.
525  if (!lastNonTerminatorInRegion(op) ||
526  !lastNonTerminatorInRegion(lastParentWithoutScope))
527  return failure();
528 
529  while (!lastParentWithoutScope->getParentOp()
530  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
531  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
532  if (!lastParentWithoutScope ||
533  !lastNonTerminatorInRegion(lastParentWithoutScope))
534  return failure();
535  }
536  assert(lastParentWithoutScope->getParentOp()
537  ->hasTrait<OpTrait::AutomaticAllocationScope>());
538 
539  Region *containingRegion = nullptr;
540  for (auto &r : lastParentWithoutScope->getRegions()) {
541  if (r.isAncestor(op->getParentRegion())) {
542  assert(containingRegion == nullptr &&
543  "only one region can contain the op");
544  containingRegion = &r;
545  }
546  }
547  assert(containingRegion && "op must be contained in a region");
548 
549  SmallVector<Operation *> toHoist;
550  op->walk([&](Operation *alloc) {
551  if (!isGuaranteedAutomaticAllocation(alloc))
552  return WalkResult::skip();
553 
554  // If any operand is not defined before the location of
555  // lastParentWithoutScope (i.e. where we would hoist to), skip.
556  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
557  return containingRegion->isAncestor(v.getParentRegion());
558  }))
559  return WalkResult::skip();
560  toHoist.push_back(alloc);
561  return WalkResult::advance();
562  });
563 
564  if (toHoist.empty())
565  return failure();
566  rewriter.setInsertionPoint(lastParentWithoutScope);
567  for (auto *op : toHoist) {
568  auto *cloned = rewriter.clone(*op);
569  rewriter.replaceOp(op, cloned->getResults());
570  }
571  return success();
572  }
573 };
574 
575 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
576  MLIRContext *context) {
577  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
578 }
579 
580 //===----------------------------------------------------------------------===//
581 // AssumeAlignmentOp
582 //===----------------------------------------------------------------------===//
583 
584 LogicalResult AssumeAlignmentOp::verify() {
585  if (!llvm::isPowerOf2_32(getAlignment()))
586  return emitOpError("alignment must be power of 2");
587  return success();
588 }
589 
590 void AssumeAlignmentOp::getAsmResultNames(
591  function_ref<void(Value, StringRef)> setNameFn) {
592  setNameFn(getResult(), "assume_align");
593 }
594 
595 OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
596  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
597  if (!source)
598  return {};
599  if (source.getAlignment() != getAlignment())
600  return {};
601  return getMemref();
602 }
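// For illustration, the fold above collapses a chain of identical alignment
// assumptions:
//
//   %0 = memref.assume_alignment %m, 64 : memref<8xf32>
//   %1 = memref.assume_alignment %0, 64 : memref<8xf32>
//
// %1 folds to %0; differing alignments are left untouched.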
603 
604 FailureOr<std::optional<SmallVector<Value>>>
605 AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
606  return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
607 }
608 
609 //===----------------------------------------------------------------------===//
610 // DistinctObjectsOp
611 //===----------------------------------------------------------------------===//
612 
613 LogicalResult DistinctObjectsOp::verify() {
614  if (getOperandTypes() != getResultTypes())
615  return emitOpError("operand types and result types must match");
616 
617  if (getOperandTypes().empty())
618  return emitOpError("expected at least one operand");
619 
620  return success();
621 }
622 
623 LogicalResult DistinctObjectsOp::inferReturnTypes(
624  MLIRContext * /*context*/, std::optional<Location> /*location*/,
625  ValueRange operands, DictionaryAttr /*attributes*/,
626  OpaqueProperties /*properties*/, RegionRange /*regions*/,
627  SmallVectorImpl<Type> &inferredReturnTypes) {
628  llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
629  return success();
630 }
631 
632 //===----------------------------------------------------------------------===//
633 // CastOp
634 //===----------------------------------------------------------------------===//
635 
636 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
637  setNameFn(getResult(), "cast");
638 }
639 
640 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
641 /// source memref. This is useful to fold a memref.cast into a consuming op
642 /// and implement canonicalization patterns for ops in different dialects that
643 /// may consume the results of memref.cast operations. Such foldable memref.cast
644 /// operations are typically inserted as `view` and `subview` ops are
645 /// canonicalized, to preserve the type compatibility of their uses.
646 ///
647 /// Returns true when all conditions are met:
648 /// 1. source and result are ranked memrefs with strided semantics and same
649 /// element type and rank.
650 /// 2. each of the source's size, offset or stride has more static information
651 /// than the corresponding result's size, offset or stride.
652 ///
653 /// Example 1:
654 /// ```mlir
655 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
656 /// %2 = consumer %1 ... : memref<?x?xf32> ...
657 /// ```
658 ///
659 /// may fold into:
660 ///
661 /// ```mlir
662 /// %2 = consumer %0 ... : memref<8x16xf32> ...
663 /// ```
664 ///
665 /// Example 2:
666 /// ```
667 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
668 /// to memref<?x?xf32>
669 /// consumer %1 : memref<?x?xf32> ...
670 /// ```
671 ///
672 /// may fold into:
673 ///
674 /// ```
675 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
676 /// ```
677 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
678  MemRefType sourceType =
679  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
680  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
681 
682  // Requires ranked MemRefType.
683  if (!sourceType || !resultType)
684  return false;
685 
686  // Requires same elemental type.
687  if (sourceType.getElementType() != resultType.getElementType())
688  return false;
689 
690  // Requires same rank.
691  if (sourceType.getRank() != resultType.getRank())
692  return false;
693 
694  // Only fold casts between strided memref forms.
695  int64_t sourceOffset, resultOffset;
696  SmallVector<int64_t, 4> sourceStrides, resultStrides;
697  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
698  failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
699  return false;
700 
701  // If cast is towards more static sizes along any dimension, don't fold.
702  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
703  auto ss = std::get<0>(it), st = std::get<1>(it);
704  if (ss != st)
705  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
706  return false;
707  }
708 
709  // If cast is towards more static offset along any dimension, don't fold.
710  if (sourceOffset != resultOffset)
711  if (ShapedType::isDynamic(sourceOffset) &&
712  ShapedType::isStatic(resultOffset))
713  return false;
714 
715  // If cast is towards more static strides along any dimension, don't fold.
716  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
717  auto ss = std::get<0>(it), st = std::get<1>(it);
718  if (ss != st)
719  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
720  return false;
721  }
722 
723  return true;
724 }
725 
726 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
727  if (inputs.size() != 1 || outputs.size() != 1)
728  return false;
729  Type a = inputs.front(), b = outputs.front();
730  auto aT = llvm::dyn_cast<MemRefType>(a);
731  auto bT = llvm::dyn_cast<MemRefType>(b);
732 
733  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
734  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
735 
736  if (aT && bT) {
737  if (aT.getElementType() != bT.getElementType())
738  return false;
739  if (aT.getLayout() != bT.getLayout()) {
740  int64_t aOffset, bOffset;
741  SmallVector<int64_t, 4> aStrides, bStrides;
742  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
743  failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
744  aStrides.size() != bStrides.size())
745  return false;
746 
747  // Strides along a dimension/offset are compatible if the value in the
748  // source memref is static and the value in the target memref is the
749  // same. They are also compatible if either one is dynamic (see
750  // description of MemRefCastOp for details).
751  auto checkCompatible = [](int64_t a, int64_t b) {
752  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
753  };
754  if (!checkCompatible(aOffset, bOffset))
755  return false;
756  for (const auto &aStride : enumerate(aStrides))
757  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
758  return false;
759  }
760  if (aT.getMemorySpace() != bT.getMemorySpace())
761  return false;
762 
763  // They must have the same rank, and any specified dimensions must match.
764  if (aT.getRank() != bT.getRank())
765  return false;
766 
767  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
768  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
769  if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
770  aDim != bDim)
771  return false;
772  }
773  return true;
774  } else {
775  if (!aT && !uaT)
776  return false;
777  if (!bT && !ubT)
778  return false;
779  // Unranked to unranked casting is unsupported
780  if (uaT && ubT)
781  return false;
782 
783  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
784  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
785  if (aEltType != bEltType)
786  return false;
787 
788  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
789  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
790  return aMemSpace == bMemSpace;
791  }
792 
793  return false;
794 }
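// For illustration, casts classified by the rules above:
//
//   memref.cast %a : memref<4xf32> to memref<?xf32>   // compatible (ranked)
//   memref.cast %a : memref<4xf32> to memref<*xf32>   // compatible (to unranked)
//   memref.cast %a : memref<*xf32> to memref<*xf32>   // incompatible:
//                                                     // unranked-to-unranked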
795 
796 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
797  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
798 }
799 
800 FailureOr<std::optional<SmallVector<Value>>>
801 CastOp::bubbleDownCasts(OpBuilder &builder) {
802  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
803 }
804 
805 //===----------------------------------------------------------------------===//
806 // CopyOp
807 //===----------------------------------------------------------------------===//
808 
809 namespace {
810 
811 /// Fold memref.copy(%x, %x).
812 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
813  using OpRewritePattern<CopyOp>::OpRewritePattern;
814 
815  LogicalResult matchAndRewrite(CopyOp copyOp,
816  PatternRewriter &rewriter) const override {
817  if (copyOp.getSource() != copyOp.getTarget())
818  return failure();
819 
820  rewriter.eraseOp(copyOp);
821  return success();
822  }
823 };
824 
825 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
826  using OpRewritePattern<CopyOp>::OpRewritePattern;
827 
828  static bool isEmptyMemRef(BaseMemRefType type) {
829  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
830  }
831 
832  LogicalResult matchAndRewrite(CopyOp copyOp,
833  PatternRewriter &rewriter) const override {
834  if (isEmptyMemRef(copyOp.getSource().getType()) ||
835  isEmptyMemRef(copyOp.getTarget().getType())) {
836  rewriter.eraseOp(copyOp);
837  return success();
838  }
839 
840  return failure();
841  }
842 };
843 } // namespace
844 
845 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
846  MLIRContext *context) {
847  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
848 }
849 
850 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
851 /// and element type, the cast can be skipped. Such CastOps only cast the layout
852 /// of the type.
853 static LogicalResult FoldCopyOfCast(CopyOp op) {
854  for (OpOperand &operand : op->getOpOperands()) {
855  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
856  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
857  operand.set(castOp.getOperand());
858  return success();
859  }
860  }
861  return failure();
862 }
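// For illustration, the fold above skips a layout-only cast feeding a copy:
//
//   %0 = memref.cast %src : memref<8x16xf32> to memref<?x?xf32>
//   memref.copy %0, %dst : memref<?x?xf32> to memref<8x16xf32>
//
// becomes:
//
//   memref.copy %src, %dst : memref<8x16xf32> to memref<8x16xf32>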
863 
864 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
865  SmallVectorImpl<OpFoldResult> &results) {
866 
867  /// copy(memrefcast) -> copy
868  return FoldCopyOfCast(*this);
869 }
870 
871 //===----------------------------------------------------------------------===//
872 // DeallocOp
873 //===----------------------------------------------------------------------===//
874 
875 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
876  SmallVectorImpl<OpFoldResult> &results) {
877  /// dealloc(memrefcast) -> dealloc
878  return foldMemRefCast(*this);
879 }
880 
881 //===----------------------------------------------------------------------===//
882 // DimOp
883 //===----------------------------------------------------------------------===//
884 
885 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
886  setNameFn(getResult(), "dim");
887 }
888 
889 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
890  int64_t index) {
891  auto loc = result.location;
892  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
893  build(builder, result, source, indexValue);
894 }
895 
896 std::optional<int64_t> DimOp::getConstantIndex() {
897  return getConstantIntValue(getIndex());
898 }
899 
900 Speculation::Speculatability DimOp::getSpeculatability() {
901  auto constantIndex = getConstantIndex();
902  if (!constantIndex)
903  return Speculation::NotSpeculatable;
904 
905  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
906  if (!rankedSourceType)
907  return Speculation::NotSpeculatable;
908 
909  if (rankedSourceType.getRank() <= constantIndex)
910  return Speculation::NotSpeculatable;
911 
912  return Speculation::Speculatable;
913 }
914 
915 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
916  SetIntLatticeFn setResultRange) {
917  setResultRange(getResult(),
918  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
919 }
920 
921 /// Return a map from each element of `vals` to its number of occurrences.
922 /// Use std::map, since the `vals` here are strides and the
923 /// dynamic stride value is the same as the tombstone value for
924 /// `DenseMap<int64_t>`.
925 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
926  std::map<int64_t, unsigned> numOccurences;
927  for (auto val : vals)
928  numOccurences[val]++;
929  return numOccurences;
930 }
931 
932 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
933 /// to be a subset of `originalType` with some `1` entries erased, return the
934 /// set of indices that specifies which of the entries of `originalShape` are
935 /// dropped to obtain `reducedShape`.
936 /// This accounts for cases where there are multiple unit-dims, but only a
937 /// subset of those are dropped. For MemRefTypes these can be disambiguated
938 /// using the strides. If a dimension is dropped the stride must be dropped too.
939 static FailureOr<llvm::SmallBitVector>
940 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
941  ArrayRef<OpFoldResult> sizes) {
942  llvm::SmallBitVector unusedDims(originalType.getRank());
943  if (originalType.getRank() == reducedType.getRank())
944  return unusedDims;
945 
946  for (const auto &dim : llvm::enumerate(sizes))
947  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
948  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
949  unusedDims.set(dim.index());
950 
951  // Early exit for the case where the number of unused dims matches the number
952  // of ranks reduced.
953  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
954  originalType.getRank())
955  return unusedDims;
956 
957  SmallVector<int64_t> originalStrides, candidateStrides;
958  int64_t originalOffset, candidateOffset;
959  if (failed(
960  originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
961  failed(
962  reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
963  return failure();
964 
965  // For memrefs, a dimension is truly dropped if its corresponding stride is
966  // also dropped. This is particularly important when more than one of the dims
967  // is 1. Track the number of occurrences of the strides in the original type
968  // and the candidate type. For each unused dim that stride should not be
969  // present in the candidate type. Note that there could be multiple dimensions
970  // that have the same size. We don't need to exactly figure out which dim
971  // corresponds to which stride, we just need to verify that the number of
972  // repetitions of a stride in the original + number of unused dims with that
973  // stride == number of repetitions of a stride in the candidate.
974  std::map<int64_t, unsigned> currUnaccountedStrides =
975  getNumOccurences(originalStrides);
976  std::map<int64_t, unsigned> candidateStridesNumOccurences =
977  getNumOccurences(candidateStrides);
978  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
979  if (!unusedDims.test(dim))
980  continue;
981  int64_t originalStride = originalStrides[dim];
982  if (currUnaccountedStrides[originalStride] >
983  candidateStridesNumOccurences[originalStride]) {
984  // This dim can be treated as dropped.
985  currUnaccountedStrides[originalStride]--;
986  continue;
987  }
988  if (currUnaccountedStrides[originalStride] ==
989  candidateStridesNumOccurences[originalStride]) {
990  // The stride for this is not dropped. Keep as is.
991  unusedDims.reset(dim);
992  continue;
993  }
994  if (currUnaccountedStrides[originalStride] <
995  candidateStridesNumOccurences[originalStride]) {
996  // This should never happen. Can't have a stride in the reduced rank type
997  // that wasn't in the original one.
998  return failure();
999  }
1000  }
1001 
1002  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1003  originalType.getRank())
1004  return failure();
1005  return unusedDims;
1006 }
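// For illustration, a rank-reduction mask computed by the helper above:
// for originalType = memref<1x4x1xf32> (strides [4, 1, 1]),
// reducedType = memref<4xf32> (stride [1]) and sizes = [1, 4, 1], both unit
// dimensions are dropped, so the returned bit vector is {0, 2}.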
1007 
1008 llvm::SmallBitVector SubViewOp::getDroppedDims() {
1009  MemRefType sourceType = getSourceType();
1010  MemRefType resultType = getType();
1011  FailureOr<llvm::SmallBitVector> unusedDims =
1012  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1013  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1014  return *unusedDims;
1015 }
1016 
1017 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1018  // All forms of folding require a known index.
1019  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1020  if (!index)
1021  return {};
1022 
1023  // Folding for unranked types (UnrankedMemRefType) is not supported.
1024  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1025  if (!memrefType)
1026  return {};
1027 
1028  // Out of bound indices produce undefined behavior but are still valid IR.
1029  // Don't choke on them.
1030  int64_t indexVal = index.getInt();
1031  if (indexVal < 0 || indexVal >= memrefType.getRank())
1032  return {};
1033 
1034  // Fold if the shape extent along the given index is known.
1035  if (!memrefType.isDynamicDim(index.getInt())) {
1036  Builder builder(getContext());
1037  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1038  }
1039 
1040  // The size at the given index is now known to be a dynamic size.
1041  unsigned unsignedIndex = index.getValue().getZExtValue();
1042 
1043  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1044  Operation *definingOp = getSource().getDefiningOp();
1045 
1046  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1047  return *(alloc.getDynamicSizes().begin() +
1048  memrefType.getDynamicDimIndex(unsignedIndex));
1049 
1050  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1051  return *(alloca.getDynamicSizes().begin() +
1052  memrefType.getDynamicDimIndex(unsignedIndex));
1053 
1054  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1055  return *(view.getDynamicSizes().begin() +
1056  memrefType.getDynamicDimIndex(unsignedIndex));
1057 
1058  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1059  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1060  unsigned resultIndex = 0;
1061  unsigned sourceRank = subview.getSourceType().getRank();
1062  unsigned sourceIndex = 0;
1063  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1064  if (unusedDims.test(i))
1065  continue;
1066  if (resultIndex == unsignedIndex) {
1067  sourceIndex = i;
1068  break;
1069  }
1070  resultIndex++;
1071  }
1072  assert(subview.isDynamicSize(sourceIndex) &&
1073  "expected dynamic subview size");
1074  return subview.getDynamicSize(sourceIndex);
1075  }
1076 
1077  if (auto sizeInterface =
1078  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1079  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1080  "Expected dynamic subview size");
1081  return sizeInterface.getDynamicSize(unsignedIndex);
1082  }
1083 
1084  // dim(memrefcast) -> dim
1085  if (succeeded(foldMemRefCast(*this)))
1086  return getResult();
1087 
1088  return {};
1089 }
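// For illustration, folds produced above for a hypothetical alloc:
//
//   %a  = memref.alloc(%n) : memref<4x?xf32>
//   %d0 = memref.dim %a, %c0 : memref<4x?xf32>   // folds to the constant 4
//   %d1 = memref.dim %a, %c1 : memref<4x?xf32>   // folds to %n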
1090 
1091 namespace {
1092 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1093 /// operand.
1094 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1095  using OpRewritePattern<DimOp>::OpRewritePattern;
1096 
1097  LogicalResult matchAndRewrite(DimOp dim,
1098  PatternRewriter &rewriter) const override {
1099  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1100 
1101  if (!reshape)
1102  return rewriter.notifyMatchFailure(
1103  dim, "Dim op is not defined by a reshape op.");
1104 
1105  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1106  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1107  // cheaply check that either of the following conditions hold:
1108  // 1. dim.getIndex() is defined in the same block as reshape but before
1109  // reshape.
1110  // 2. dim.getIndex() is defined in a parent block of
1111  // reshape.
1112 
1113  // Check condition 1
1114  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1115  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1116  if (reshape->isBeforeInBlock(definingOp)) {
1117  return rewriter.notifyMatchFailure(
1118  dim,
1119  "dim.getIndex is not defined before reshape in the same block.");
1120  }
1121  } // else dim.getIndex is a block argument to reshape->getBlock and
1122  // dominates reshape
1123  } // Check condition 2
1124  else if (dim->getBlock() != reshape->getBlock() &&
1125  !dim.getIndex().getParentRegion()->isProperAncestor(
1126  reshape->getParentRegion())) {
1127  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1128  // already know dim.getIndex() dominates reshape without calling
1129  // `isProperAncestor`
1130  return rewriter.notifyMatchFailure(
1131  dim, "dim.getIndex does not dominate reshape.");
1132  }
1133 
1134  // Place the load directly after the reshape to ensure that the shape memref
1135  // was not mutated.
1136  rewriter.setInsertionPointAfter(reshape);
1137  Location loc = dim.getLoc();
1138  Value load =
1139  LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1140  if (load.getType() != dim.getType())
1141  load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
1142  rewriter.replaceOp(dim, load);
1143  return success();
1144  }
1145 };
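// For illustration, the rewrite performed by DimOfMemRefReshape:
//
//   %r = memref.reshape %src(%shape)
//       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %c1 : memref<?x?xf32>
//
// becomes a load of the shape operand (no index_cast is needed here since the
// shape element type is already index):
//
//   %d = memref.load %shape[%c1] : memref<2xindex>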
1146 
1147 } // namespace
1148 
1149 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1150  MLIRContext *context) {
1151  results.add<DimOfMemRefReshape>(context);
1152 }
1153 
1154 // ---------------------------------------------------------------------------
1155 // DmaStartOp
1156 // ---------------------------------------------------------------------------
1157 
1158 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1159  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1160  ValueRange destIndices, Value numElements,
1161  Value tagMemRef, ValueRange tagIndices, Value stride,
1162  Value elementsPerStride) {
1163  result.addOperands(srcMemRef);
1164  result.addOperands(srcIndices);
1165  result.addOperands(destMemRef);
1166  result.addOperands(destIndices);
1167  result.addOperands({numElements, tagMemRef});
1168  result.addOperands(tagIndices);
1169  if (stride)
1170  result.addOperands({stride, elementsPerStride});
1171 }
1172 
1173 void DmaStartOp::print(OpAsmPrinter &p) {
1174  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1175  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1176  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1177  if (isStrided())
1178  p << ", " << getStride() << ", " << getNumElementsPerStride();
1179 
1180  p.printOptionalAttrDict((*this)->getAttrs());
1181  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1182  << ", " << getTagMemRef().getType();
1183 }
1184 
1185 // Parse DmaStartOp.
1186 // Ex:
1187 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1188 // %tag[%index], %stride, %num_elt_per_stride :
1189 // : memref<3076 x f32, 0>,
1190 // memref<1024 x f32, 2>,
1191 // memref<1 x i32>
1192 //
1193 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1194  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1195  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1196  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1197  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1198  OpAsmParser::UnresolvedOperand numElementsInfo;
1199  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1200  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1201  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1202 
1203  SmallVector<Type, 3> types;
1204  auto indexType = parser.getBuilder().getIndexType();
1205 
1206  // Parse and resolve the following list of operands:
1207  // *) source memref followed by its indices (in square brackets).
1208  // *) destination memref followed by its indices (in square brackets).
1209  // *) dma size in KiB.
1210  if (parser.parseOperand(srcMemRefInfo) ||
1211  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1212  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1213  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1214  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1215  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1216  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1217  return failure();
1218 
1219  // Parse optional stride and elements per stride.
1220  if (parser.parseTrailingOperandList(strideInfo))
1221  return failure();
1222 
1223  bool isStrided = strideInfo.size() == 2;
1224  if (!strideInfo.empty() && !isStrided) {
1225  return parser.emitError(parser.getNameLoc(),
1226  "expected two stride related operands");
1227  }
1228 
1229  if (parser.parseColonTypeList(types))
1230  return failure();
1231  if (types.size() != 3)
1232  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1233 
1234  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1235  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1236  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1237  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1238  // size should be an index.
1239  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1240  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1241  // tag indices should be index.
1242  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1243  return failure();
1244 
1245  if (isStrided) {
1246  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1247  return failure();
1248  }
1249 
1250  return success();
1251 }
1252 
1253 LogicalResult DmaStartOp::verify() {
1254  unsigned numOperands = getNumOperands();
1255 
1256  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1257  // the number of elements.
1258  if (numOperands < 4)
1259  return emitOpError("expected at least 4 operands");
1260 
1261  // Check types of operands. The order of these calls is important: the later
1262  // calls rely on some type properties to compute the operand position.
1263  // 1. Source memref.
1264  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1265  return emitOpError("expected source to be of memref type");
1266  if (numOperands < getSrcMemRefRank() + 4)
1267  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1268  << " operands";
1269  if (!getSrcIndices().empty() &&
1270  !llvm::all_of(getSrcIndices().getTypes(),
1271  [](Type t) { return t.isIndex(); }))
1272  return emitOpError("expected source indices to be of index type");
1273 
1274  // 2. Destination memref.
1275  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1276  return emitOpError("expected destination to be of memref type");
1277  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1278  if (numOperands < numExpectedOperands)
1279  return emitOpError() << "expected at least " << numExpectedOperands
1280  << " operands";
1281  if (!getDstIndices().empty() &&
1282  !llvm::all_of(getDstIndices().getTypes(),
1283  [](Type t) { return t.isIndex(); }))
1284  return emitOpError("expected destination indices to be of index type");
1285 
1286  // 3. Number of elements.
1287  if (!getNumElements().getType().isIndex())
1288  return emitOpError("expected num elements to be of index type");
1289 
1290  // 4. Tag memref.
1291  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1292  return emitOpError("expected tag to be of memref type");
1293  numExpectedOperands += getTagMemRefRank();
1294  if (numOperands < numExpectedOperands)
1295  return emitOpError() << "expected at least " << numExpectedOperands
1296  << " operands";
1297  if (!getTagIndices().empty() &&
1298  !llvm::all_of(getTagIndices().getTypes(),
1299  [](Type t) { return t.isIndex(); }))
1300  return emitOpError("expected tag indices to be of index type");
1301 
1302  // Optional stride-related operands must be either both present or both
1303  // absent.
1304  if (numOperands != numExpectedOperands &&
1305  numOperands != numExpectedOperands + 2)
1306  return emitOpError("incorrect number of operands");
1307 
1308  // 5. Strides.
1309  if (isStrided()) {
1310  if (!getStride().getType().isIndex() ||
1311  !getNumElementsPerStride().getType().isIndex())
1312  return emitOpError(
1313  "expected stride and num elements per stride to be of type index");
1314  }
1315 
1316  return success();
1317 }
1318 
1319 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1320  SmallVectorImpl<OpFoldResult> &results) {
1321  /// dma_start(memrefcast) -> dma_start
1322  return foldMemRefCast(*this);
1323 }
1324 
1325 // ---------------------------------------------------------------------------
1326 // DmaWaitOp
1327 // ---------------------------------------------------------------------------
1328 
1329 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1330  SmallVectorImpl<OpFoldResult> &results) {
1331  /// dma_wait(memrefcast) -> dma_wait
1332  return foldMemRefCast(*this);
1333 }
1334 
1335 LogicalResult DmaWaitOp::verify() {
1336  // Check that the number of tag indices matches the tagMemRef rank.
1337  unsigned numTagIndices = getTagIndices().size();
1338  unsigned tagMemRefRank = getTagMemRefRank();
1339  if (numTagIndices != tagMemRefRank)
1340  return emitOpError() << "expected tagIndices to have the same number of "
1341  "elements as the tagMemRef rank, expected "
1342  << tagMemRefRank << ", but got " << numTagIndices;
1343  return success();
1344 }
1345 
1346 //===----------------------------------------------------------------------===//
1347 // ExtractAlignedPointerAsIndexOp
1348 //===----------------------------------------------------------------------===//
1349 
1350 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1351  function_ref<void(Value, StringRef)> setNameFn) {
1352  setNameFn(getResult(), "intptr");
1353 }
1354 
1355 //===----------------------------------------------------------------------===//
1356 // ExtractStridedMetadataOp
1357 //===----------------------------------------------------------------------===//
1358 
1359 /// The number and type of the results are inferred from the
1360 /// shape of the source.
1361 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1362  MLIRContext *context, std::optional<Location> location,
1363  ExtractStridedMetadataOp::Adaptor adaptor,
1364  SmallVectorImpl<Type> &inferredReturnTypes) {
1365  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1366  if (!sourceType)
1367  return failure();
1368 
1369  unsigned sourceRank = sourceType.getRank();
1370  IndexType indexType = IndexType::get(context);
1371  auto memrefType =
1372  MemRefType::get({}, sourceType.getElementType(),
1373  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1374  // Base.
1375  inferredReturnTypes.push_back(memrefType);
1376  // Offset.
1377  inferredReturnTypes.push_back(indexType);
1378  // Sizes and strides.
1379  for (unsigned i = 0; i < sourceRank * 2; ++i)
1380  inferredReturnTypes.push_back(indexType);
1381  return success();
1382 }
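// For illustration, the result layout inferred above for a rank-2 source %m of
// type memref<4x?xf32>:
//
//   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
//
// where %base is a rank-0 memref<f32> in the source's memory space and the
// remaining five results are of type index.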
1383 
1384 void ExtractStridedMetadataOp::getAsmResultNames(
1385  function_ref<void(Value, StringRef)> setNameFn) {
1386  setNameFn(getBaseBuffer(), "base_buffer");
1387  setNameFn(getOffset(), "offset");
1388  // For multi-result to work properly with pretty names and packed syntax `x:3`
1389  // we can only give a pretty name to the first value in the pack.
1390  if (!getSizes().empty()) {
1391  setNameFn(getSizes().front(), "sizes");
1392  setNameFn(getStrides().front(), "strides");
1393  }
1394 }
1395 
1396 /// Helper function to perform the replacement of all constant uses of `values`
1397 /// by a materialized constant extracted from `maybeConstants`.
1398 /// `values` and `maybeConstants` are expected to have the same size.
1399 template <typename Container>
1400 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1401  Container values,
1402  ArrayRef<OpFoldResult> maybeConstants) {
1403  assert(values.size() == maybeConstants.size() &&
1404  " expected values and maybeConstants of the same size");
1405  bool atLeastOneReplacement = false;
1406  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1407  // Don't materialize a constant if there are no uses: this would induce
1408  // infinite loops in the driver.
1409  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1410  continue;
1411  assert(isa<Attribute>(maybeConstant) &&
1412  "The constified value should be either unchanged (i.e., == result) "
1413  "or a constant");
1414  Value constantVal = arith::ConstantIndexOp::create(
1415  rewriter, loc,
1416  llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1417  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1418  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1419  // yet.
1420  op->replaceUsesOfWith(result, constantVal);
1421  atLeastOneReplacement = true;
1422  }
1423  }
1424  return atLeastOneReplacement;
1425 }
1426 
1427 LogicalResult
1428 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1429  SmallVectorImpl<OpFoldResult> &results) {
1430  OpBuilder builder(*this);
1431 
1432  bool atLeastOneReplacement = replaceConstantUsesOf(
1433  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1434  getConstifiedMixedOffset());
1435  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1436  getConstifiedMixedSizes());
1437  atLeastOneReplacement |= replaceConstantUsesOf(
1438  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1439 
1440  // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
1441  if (auto prev = getSource().getDefiningOp<CastOp>())
1442  if (isa<MemRefType>(prev.getSource().getType())) {
1443  getSourceMutable().assign(prev.getSource());
1444  atLeastOneReplacement = true;
1445  }
1446 
1447  return success(atLeastOneReplacement);
1448 }
1449 
1450 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1451  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1452  constifyIndexValues(values, getSource().getType().getShape());
1453  return values;
1454 }
1455 
1456 SmallVector<OpFoldResult>
1457 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1458  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1459  SmallVector<int64_t> staticValues;
1460  int64_t unused;
1461  LogicalResult status =
1462  getSource().getType().getStridesAndOffset(staticValues, unused);
1463  (void)status;
1464  assert(succeeded(status) && "could not get strides from type");
1465  constifyIndexValues(values, staticValues);
1466  return values;
1467 }
1468 
1469 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1470  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1471  SmallVector<OpFoldResult> values(1, offsetOfr);
1472  SmallVector<int64_t> staticValues, unused;
1473  int64_t offset;
1474  LogicalResult status =
1475  getSource().getType().getStridesAndOffset(unused, offset);
1476  (void)status;
1477  assert(succeeded(status) && "could not get offset from type");
1478  staticValues.push_back(offset);
1479  constifyIndexValues(values, staticValues);
1480  return values[0];
1481 }
1482 
1483 //===----------------------------------------------------------------------===//
1484 // GenericAtomicRMWOp
1485 //===----------------------------------------------------------------------===//
1486 
1487 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1488  Value memref, ValueRange ivs) {
1489  OpBuilder::InsertionGuard g(builder);
1490  result.addOperands(memref);
1491  result.addOperands(ivs);
1492 
1493  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1494  Type elementType = memrefType.getElementType();
1495  result.addTypes(elementType);
1496 
1497  Region *bodyRegion = result.addRegion();
1498  builder.createBlock(bodyRegion);
1499  bodyRegion->addArgument(elementType, memref.getLoc());
1500  }
1501 }
1502 
1503 LogicalResult GenericAtomicRMWOp::verify() {
1504  auto &body = getRegion();
1505  if (body.getNumArguments() != 1)
1506  return emitOpError("expected a single entry block argument");
1507 
1508  if (getResult().getType() != body.getArgument(0).getType())
1509  return emitOpError("expected block argument of the same type as the result type");
1510 
1511  bool hasSideEffects =
1512  body.walk([&](Operation *nestedOp) {
1513  if (isMemoryEffectFree(nestedOp))
1514  return WalkResult::advance();
1515  nestedOp->emitError(
1516  "body of 'memref.generic_atomic_rmw' should contain "
1517  "only operations with no side effects");
1518  return WalkResult::interrupt();
1519  })
1520  .wasInterrupted();
1521  return hasSideEffects ? failure() : success();
1522 }
1523 
1524 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1525  OperationState &result) {
1526  OpAsmParser::UnresolvedOperand memref;
1527  Type memrefType;
1528  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1529 
1530  Type indexType = parser.getBuilder().getIndexType();
1531  if (parser.parseOperand(memref) ||
1532  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1533  parser.parseColonType(memrefType) ||
1534  parser.resolveOperand(memref, memrefType, result.operands) ||
1535  parser.resolveOperands(ivs, indexType, result.operands))
1536  return failure();
1537 
1538  Region *body = result.addRegion();
1539  if (parser.parseRegion(*body, {}) ||
1540  parser.parseOptionalAttrDict(result.attributes))
1541  return failure();
1542  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1543  return success();
1544 }
1545 
1546 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1547  p << ' ' << getMemref() << "[" << getIndices()
1548  << "] : " << getMemref().getType() << ' ';
1549  p.printRegion(getRegion());
1550  p.printOptionalAttrDict((*this)->getAttrs());
1551 }
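// Editor's note: a sketch, not from the upstream file, of the textual form
// handled by the parse/print methods above (SSA names are hypothetical):
//
// ```
//   %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
//   ^bb0(%current_value : f32):
//     %cst = arith.constant 1.0 : f32
//     %new = arith.addf %current_value, %cst : f32
//     memref.atomic_yield %new : f32
//   }
// ```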
1552 
1553 //===----------------------------------------------------------------------===//
1554 // AtomicYieldOp
1555 //===----------------------------------------------------------------------===//
1556 
1557 LogicalResult AtomicYieldOp::verify() {
1558  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1559  Type resultType = getResult().getType();
1560  if (parentType != resultType)
1561  return emitOpError() << "types mismatch between yield op: " << resultType
1562  << " and its parent: " << parentType;
1563  return success();
1564 }
1565 
1566 //===----------------------------------------------------------------------===//
1567 // GlobalOp
1568 //===----------------------------------------------------------------------===//
1569 
1570 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1571  TypeAttr type,
1572  Attribute initialValue) {
1573  p << type;
1574  if (!op.isExternal()) {
1575  p << " = ";
1576  if (op.isUninitialized())
1577  p << "uninitialized";
1578  else
1579  p.printAttributeWithoutType(initialValue);
1580  }
1581 }
1582 
1583 static ParseResult
1584 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1585  Attribute &initialValue) {
1586  Type type;
1587  if (parser.parseType(type))
1588  return failure();
1589 
1590  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1591  if (!memrefType || !memrefType.hasStaticShape())
1592  return parser.emitError(parser.getNameLoc())
1593  << "type should be static shaped memref, but got " << type;
1594  typeAttr = TypeAttr::get(type);
1595 
1596  if (parser.parseOptionalEqual())
1597  return success();
1598 
1599  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1600  initialValue = UnitAttr::get(parser.getContext());
1601  return success();
1602  }
1603 
1604  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1605  if (parser.parseAttribute(initialValue, tensorType))
1606  return failure();
1607  if (!llvm::isa<ElementsAttr>(initialValue))
1608  return parser.emitError(parser.getNameLoc())
1609  << "initial value should be a unit or elements attribute";
1610  return success();
1611 }
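// Editor's note: hedged examples, adapted from the op documentation rather
// than this file, of the textual forms the two helpers above print and parse:
//
// ```
//   // Initialized global; the initial value is printed without its type.
//   memref.global "private" @x : memref<2xf32> = dense<[0.0, 2.0]>
//
//   // Uninitialized global.
//   memref.global "private" @y : memref<4xi32> = uninitialized
//
//   // External global: no initializer at all.
//   memref.global @z : memref<3xf16>
// ```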
1612 
1613 LogicalResult GlobalOp::verify() {
1614  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1615  if (!memrefType || !memrefType.hasStaticShape())
1616  return emitOpError("type should be static shaped memref, but got ")
1617  << getType();
1618 
1619  // Verify that the initial value, if present, is either a unit attribute or
1620  // an elements attribute.
1621  if (getInitialValue().has_value()) {
1622  Attribute initValue = getInitialValue().value();
1623  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1624  return emitOpError("initial value should be a unit or elements "
1625  "attribute, but got ")
1626  << initValue;
1627 
1628  // Check that the type of the initial value is compatible with the type of
1629  // the global variable.
1630  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1631  // Check the element types match.
1632  auto initElementType =
1633  cast<TensorType>(elementsAttr.getType()).getElementType();
1634  auto memrefElementType = memrefType.getElementType();
1635 
1636  if (initElementType != memrefElementType)
1637  return emitOpError("initial value element expected to be of type ")
1638  << memrefElementType << ", but was of type " << initElementType;
1639 
1640  // Check the shapes match, given that memref globals can only produce
1641  // statically shaped memrefs and elements literal type must have a static
1642  // shape we can assume both types are shaped.
1643  auto initShape = elementsAttr.getShapedType().getShape();
1644  auto memrefShape = memrefType.getShape();
1645  if (initShape != memrefShape)
1646  return emitOpError("initial value shape expected to be ")
1647  << memrefShape << " but was " << initShape;
1648  }
1649  }
1650 
1651  // TODO: verify visibility for declarations.
1652  return success();
1653 }
1654 
1655 ElementsAttr GlobalOp::getConstantInitValue() {
1656  auto initVal = getInitialValue();
1657  if (getConstant() && initVal.has_value())
1658  return llvm::cast<ElementsAttr>(initVal.value());
1659  return {};
1660 }
1661 
1662 //===----------------------------------------------------------------------===//
1663 // GetGlobalOp
1664 //===----------------------------------------------------------------------===//
1665 
1666 LogicalResult
1667 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1668  // Verify that the result type is same as the type of the referenced
1669  // memref.global op.
1670  auto global =
1671  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1672  if (!global)
1673  return emitOpError("'")
1674  << getName() << "' does not reference a valid global memref";
1675 
1676  Type resultType = getResult().getType();
1677  if (global.getType() != resultType)
1678  return emitOpError("result type ")
1679  << resultType << " does not match type " << global.getType()
1680  << " of the global memref @" << getName();
1681  return success();
1682 }
1683 
1684 //===----------------------------------------------------------------------===//
1685 // LoadOp
1686 //===----------------------------------------------------------------------===//
1687 
1688 LogicalResult LoadOp::verify() {
1689  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1690  return emitOpError("incorrect number of indices for load, expected ")
1691  << getMemRefType().getRank() << " but got " << getIndices().size();
1692  }
1693  return success();
1694 }
1695 
1696 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1697  /// load(memrefcast) -> load
1698  if (succeeded(foldMemRefCast(*this)))
1699  return getResult();
1700  return OpFoldResult();
1701 }
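// Editor's note: a minimal sketch, not from the upstream sources, of the
// memref.cast folding performed in the fold above:
//
// ```
//   %0 = memref.cast %m : memref<4xf32> to memref<?xf32>
//   %1 = memref.load %0[%i] : memref<?xf32>
// ```
// folds the operand so the load reads directly through %m:
// ```
//   %1 = memref.load %m[%i] : memref<4xf32>
// ```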
1702 
1703 FailureOr<std::optional<SmallVector<Value>>>
1704 LoadOp::bubbleDownCasts(OpBuilder &builder) {
1705  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
1706  getResult());
1707 }
1708 
1709 //===----------------------------------------------------------------------===//
1710 // MemorySpaceCastOp
1711 //===----------------------------------------------------------------------===//
1712 
1713 void MemorySpaceCastOp::getAsmResultNames(
1714  function_ref<void(Value, StringRef)> setNameFn) {
1715  setNameFn(getResult(), "memspacecast");
1716 }
1717 
1718 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1719  if (inputs.size() != 1 || outputs.size() != 1)
1720  return false;
1721  Type a = inputs.front(), b = outputs.front();
1722  auto aT = llvm::dyn_cast<MemRefType>(a);
1723  auto bT = llvm::dyn_cast<MemRefType>(b);
1724 
1725  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1726  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1727 
1728  if (aT && bT) {
1729  if (aT.getElementType() != bT.getElementType())
1730  return false;
1731  if (aT.getLayout() != bT.getLayout())
1732  return false;
1733  if (aT.getShape() != bT.getShape())
1734  return false;
1735  return true;
1736  }
1737  if (uaT && ubT) {
1738  return uaT.getElementType() == ubT.getElementType();
1739  }
1740  return false;
1741 }
1742 
1743 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1744  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1745  // t2)
1746  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1747  getSourceMutable().assign(parentCast.getSource());
1748  return getResult();
1749  }
1750  return Value{};
1751 }
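// Editor's note: a sketch, not from the upstream file, of the fold above; the
// memory spaces and SSA names are hypothetical:
//
// ```
//   %1 = memref.memory_space_cast %0 : memref<4xf32> to memref<4xf32, 1>
//   %2 = memref.memory_space_cast %1 : memref<4xf32, 1> to memref<4xf32, 2>
// ```
// folds %2 into a single cast from the original source:
// ```
//   %2 = memref.memory_space_cast %0 : memref<4xf32> to memref<4xf32, 2>
// ```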
1752 
1753 TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
1754  return getSource();
1755 }
1756 
1757 TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
1758  return getDest();
1759 }
1760 
1761 bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1762  PtrLikeTypeInterface src) {
1763  return isa<BaseMemRefType>(tgt) &&
1764  tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
1765 }
1766 
1767 MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
1768  OpBuilder &b, PtrLikeTypeInterface tgt,
1769  TypedValue<PtrLikeTypeInterface> src) {
1770  assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
1771  return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
1772 }
1773 
1774 /// The only cast we recognize as promotable is to the generic space.
1775 bool MemorySpaceCastOp::isSourcePromotable() {
1776  return getDest().getType().getMemorySpace() == nullptr;
1777 }
1778 
1779 //===----------------------------------------------------------------------===//
1780 // PrefetchOp
1781 //===----------------------------------------------------------------------===//
1782 
1783 void PrefetchOp::print(OpAsmPrinter &p) {
1784  p << " " << getMemref() << '[';
1785  p.printOperands(getIndices());
1786  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1787  p << ", locality<" << getLocalityHint();
1788  p << ">, " << (getIsDataCache() ? "data" : "instr");
1789  p.printOptionalAttrDict(
1790  (*this)->getAttrs(),
1791  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1792  p << " : " << getMemRefType();
1793 }
1794 
1795 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1796  OpAsmParser::UnresolvedOperand memrefInfo;
1797  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1798  IntegerAttr localityHint;
1799  MemRefType type;
1800  StringRef readOrWrite, cacheType;
1801 
1802  auto indexTy = parser.getBuilder().getIndexType();
1803  auto i32Type = parser.getBuilder().getIntegerType(32);
1804  if (parser.parseOperand(memrefInfo) ||
1805  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1806  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1807  parser.parseComma() || parser.parseKeyword("locality") ||
1808  parser.parseLess() ||
1809  parser.parseAttribute(localityHint, i32Type, "localityHint",
1810  result.attributes) ||
1811  parser.parseGreater() || parser.parseComma() ||
1812  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1813  parser.resolveOperand(memrefInfo, type, result.operands) ||
1814  parser.resolveOperands(indexInfo, indexTy, result.operands))
1815  return failure();
1816 
1817  if (readOrWrite != "read" && readOrWrite != "write")
1818  return parser.emitError(parser.getNameLoc(),
1819  "rw specifier has to be 'read' or 'write'");
1820  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1821  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1822 
1823  if (cacheType != "data" && cacheType != "instr")
1824  return parser.emitError(parser.getNameLoc(),
1825  "cache type has to be 'data' or 'instr'");
1826 
1827  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1828  parser.getBuilder().getBoolAttr(cacheType == "data"));
1829 
1830  return success();
1831 }
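// Editor's note: a hedged example (names are hypothetical) of the custom
// assembly handled by the print/parse methods above:
//
// ```
//   memref.prefetch %A[%i, %j], read, locality<3>, data : memref<400x400xf32>
// ```
// The rw specifier must be `read` or `write`, the locality hint is an i32
// attribute, and the cache-type keyword must be `data` or `instr`.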
1832 
1833 LogicalResult PrefetchOp::verify() {
1834  if (getNumOperands() != 1 + getMemRefType().getRank())
1835  return emitOpError("too few indices");
1836 
1837  return success();
1838 }
1839 
1840 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1841  SmallVectorImpl<OpFoldResult> &results) {
1842  // prefetch(memrefcast) -> prefetch
1843  return foldMemRefCast(*this);
1844 }
1845 
1846 //===----------------------------------------------------------------------===//
1847 // RankOp
1848 //===----------------------------------------------------------------------===//
1849 
1850 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1851  // Constant fold rank when the rank of the operand is known.
1852  auto type = getOperand().getType();
1853  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1854  if (shapedType && shapedType.hasRank())
1855  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1856  return IntegerAttr();
1857 }
1858 
1859 //===----------------------------------------------------------------------===//
1860 // ReinterpretCastOp
1861 //===----------------------------------------------------------------------===//
1862 
1863 void ReinterpretCastOp::getAsmResultNames(
1864  function_ref<void(Value, StringRef)> setNameFn) {
1865  setNameFn(getResult(), "reinterpret_cast");
1866 }
1867 
1868 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1869 /// `staticSizes` and `staticStrides` are automatically filled with
1870 /// source-memref-rank sentinel values that encode dynamic entries.
1871 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1872  MemRefType resultType, Value source,
1873  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1874  ArrayRef<OpFoldResult> strides,
1875  ArrayRef<NamedAttribute> attrs) {
1876  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1877  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1878  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1879  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1880  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1881  result.addAttributes(attrs);
1882  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1883  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1884  b.getDenseI64ArrayAttr(staticSizes),
1885  b.getDenseI64ArrayAttr(staticStrides));
1886 }
1887 
1888 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1889  Value source, OpFoldResult offset,
1890  ArrayRef<OpFoldResult> sizes,
1891  ArrayRef<OpFoldResult> strides,
1892  ArrayRef<NamedAttribute> attrs) {
1893  auto sourceType = cast<BaseMemRefType>(source.getType());
1894  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1895  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1896  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1897  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1898  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1899  auto stridedLayout = StridedLayoutAttr::get(
1900  b.getContext(), staticOffsets.front(), staticStrides);
1901  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1902  stridedLayout, sourceType.getMemorySpace());
1903  build(b, result, resultType, source, offset, sizes, strides, attrs);
1904 }
1905 
1906 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1907  MemRefType resultType, Value source,
1908  int64_t offset, ArrayRef<int64_t> sizes,
1909  ArrayRef<int64_t> strides,
1910  ArrayRef<NamedAttribute> attrs) {
1911  SmallVector<OpFoldResult> sizeValues =
1912  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1913  return b.getI64IntegerAttr(v);
1914  }));
1915  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1916  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1917  return b.getI64IntegerAttr(v);
1918  }));
1919  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1920  strideValues, attrs);
1921 }
1922 
1923 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1924  MemRefType resultType, Value source, Value offset,
1925  ValueRange sizes, ValueRange strides,
1926  ArrayRef<NamedAttribute> attrs) {
1927  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1928  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1929  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1930  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1931  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1932 }
1933 
1934 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1935 // completed automatically, like we have for subview and extract_slice.
1936 LogicalResult ReinterpretCastOp::verify() {
1937  // The source and result memrefs should be in the same memory space.
1938  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1939  auto resultType = llvm::cast<MemRefType>(getType());
1940  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1941  return emitError("different memory spaces specified for source type ")
1942  << srcType << " and result memref type " << resultType;
1943  if (srcType.getElementType() != resultType.getElementType())
1944  return emitError("different element types specified for source type ")
1945  << srcType << " and result memref type " << resultType;
1946 
1947  // Match sizes in result memref type and in static_sizes attribute.
1948  for (auto [idx, resultSize, expectedSize] :
1949  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1950  if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
1951  return emitError("expected result type with size = ")
1952  << (ShapedType::isDynamic(expectedSize)
1953  ? std::string("dynamic")
1954  : std::to_string(expectedSize))
1955  << " instead of " << resultSize << " in dim = " << idx;
1956  }
1957 
1958  // Match offset and strides in the static_offsets and static_strides attributes. If
1959  // result memref type has no affine map specified, this will assume an
1960  // identity layout.
1961  int64_t resultOffset;
1962  SmallVector<int64_t, 4> resultStrides;
1963  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1964  return emitError("expected result type to have strided layout but found ")
1965  << resultType;
1966 
1967  // Match offset in result memref type and in static_offsets attribute.
1968  int64_t expectedOffset = getStaticOffsets().front();
1969  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
1970  return emitError("expected result type with offset = ")
1971  << (ShapedType::isDynamic(expectedOffset)
1972  ? std::string("dynamic")
1973  : std::to_string(expectedOffset))
1974  << " instead of " << resultOffset;
1975 
1976  // Match strides in result memref type and in static_strides attribute.
1977  for (auto [idx, resultStride, expectedStride] :
1978  llvm::enumerate(resultStrides, getStaticStrides())) {
1979  if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
1980  return emitError("expected result type with stride = ")
1981  << (ShapedType::isDynamic(expectedStride)
1982  ? std::string("dynamic")
1983  : std::to_string(expectedStride))
1984  << " instead of " << resultStride << " in dim = " << idx;
1985  }
1986 
1987  return success();
1988 }
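// Editor's note: a hedged example, adapted from the op documentation rather
// than this file, of a reinterpret_cast that satisfies the checks above: the
// static entries of the operands agree with the strided layout of the result
// type.
//
// ```
//   %dst = memref.reinterpret_cast %src to
//       offset: [0], sizes: [%size0, 10], strides: [1, %stride1]
//     : memref<?x?xf32> to memref<?x10xf32, strided<[1, ?], offset: 0>>
// ```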
1989 
1990 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1991  Value src = getSource();
1992  auto getPrevSrc = [&]() -> Value {
1993  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1994  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1995  return prev.getSource();
1996 
1997  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1998  if (auto prev = src.getDefiningOp<CastOp>())
1999  return prev.getSource();
2000 
2001  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
2002  // are 0.
2003  if (auto prev = src.getDefiningOp<SubViewOp>())
2004  if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
2005  return prev.getSource();
2006 
2007  return nullptr;
2008  };
2009 
2010  if (auto prevSrc = getPrevSrc()) {
2011  getSourceMutable().assign(prevSrc);
2012  return getResult();
2013  }
2014 
2015  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
2016  if (ShapedType::isStaticShape(getType().getShape()) &&
2017  src.getType() == getType() && getStaticOffsets().front() == 0) {
2018  return src;
2019  }
2020 
2021  return nullptr;
2022 }
2023 
2024 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
2025  SmallVector<OpFoldResult> values = getMixedSizes();
2026  constifyIndexValues(values, getType().getShape());
2027  return values;
2028 }
2029 
2030 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
2031  SmallVector<OpFoldResult> values = getMixedStrides();
2032  SmallVector<int64_t> staticValues;
2033  int64_t unused;
2034  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
2035  (void)status;
2036  assert(succeeded(status) && "could not get strides from type");
2037  constifyIndexValues(values, staticValues);
2038  return values;
2039 }
2040 
2041 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
2042  SmallVector<OpFoldResult> values = getMixedOffsets();
2043  assert(values.size() == 1 &&
2044  "reinterpret_cast must have one and only one offset");
2045  SmallVector<int64_t> staticValues, unused;
2046  int64_t offset;
2047  LogicalResult status = getType().getStridesAndOffset(unused, offset);
2048  (void)status;
2049  assert(succeeded(status) && "could not get offset from type");
2050  staticValues.push_back(offset);
2051  constifyIndexValues(values, staticValues);
2052  return values[0];
2053 }
2054 
2055 namespace {
2056 /// Replace the sequence:
2057 /// ```
2058 /// base, offset, sizes, strides = extract_strided_metadata src
2059 /// dst = reinterpret_cast base to offset, sizes, strides
2060 /// ```
2061 /// With
2062 ///
2063 /// ```
2064 /// dst = memref.cast src
2065 /// ```
2066 ///
2067 /// Note: The cast operation is only inserted when the type of dst and src
2068 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
2069 ///
2070 /// This pattern also matches when the offset, sizes, and strides don't come
2071 /// directly from the `extract_strided_metadata`'s results but it can be
2072 /// statically proven that they would hold the same values.
2073 ///
2074 /// For instance, the following sequence would be replaced:
2075 /// ```
2076 /// base, offset, sizes, strides =
2077 /// extract_strided_metadata memref : memref<3x4xty>
2078 /// dst = reinterpret_cast base to 0, [3, 4], strides
2079 /// ```
2080 /// Because we know (thanks to the type of the input memref) that variable
2081 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
2082 ///
2083 /// Similarly, the following sequence would be replaced:
2084 /// ```
2085 /// c0 = arith.constant 0
2086 /// c4 = arith.constant 4
2087 /// base, offset, sizes, strides =
2088 /// extract_strided_metadata memref : memref<3x4xty>
2089 /// dst = reinterpret_cast base to c0, [3, c4], strides
2090 /// ```
2091 /// Because we know that `offset` and `c0` will hold 0
2092 /// and `c4` will hold 4.
2093 ///
2094 /// If the pattern above does not match, the input of the
2095 /// extract_strided_metadata is always folded into the input of the
2096 /// reinterpret_cast operator. This allows for dead code elimination to get rid
2097 /// of the extract_strided_metadata in some cases.
2098 struct ReinterpretCastOpExtractStridedMetadataFolder
2099  : public OpRewritePattern<ReinterpretCastOp> {
2100 public:
2101  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2102 
2103  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2104  PatternRewriter &rewriter) const override {
2105  auto extractStridedMetadata =
2106  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2107  if (!extractStridedMetadata)
2108  return failure();
2109 
2110  // Check if the reinterpret cast reconstructs a memref with the exact same
2111  // properties as the extract strided metadata.
2112  auto isReinterpretCastNoop = [&]() -> bool {
2113  // First, check that the strides are the same.
2114  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
2115  op.getConstifiedMixedStrides()))
2116  return false;
2117 
2118  // Second, check the sizes.
2119  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
2120  op.getConstifiedMixedSizes()))
2121  return false;
2122 
2123  // Finally, check the offset.
2124  assert(op.getMixedOffsets().size() == 1 &&
2125  "reinterpret_cast with more than one offset should have been "
2126  "rejected by the verifier");
2127  return extractStridedMetadata.getConstifiedMixedOffset() ==
2128  op.getConstifiedMixedOffset();
2129  };
2130 
2131  if (!isReinterpretCastNoop()) {
2132  // If the extract_strided_metadata / reinterpret_cast pair can't be
2133  // completely folded, we can still fold the input of the
2134  // extract_strided_metadata into the input of the reinterpret_cast.
2135  // In some cases (e.g., static dimensions), the extract_strided_metadata
2136  // is then eliminated by dead code elimination.
2137  //
2138  // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2139  //
2140  // We can always fold the input of a extract_strided_metadata operator
2141  // to the input of a reinterpret_cast operator, because they point to
2142  // the same memory. Note that the reinterpret_cast does not use the
2143  // layout of its input memref, only its base memory pointer which is
2144  // the same as the base pointer returned by the extract_strided_metadata
2145  // operator and the base pointer of the extract_strided_metadata memref
2146  // input.
2147  rewriter.modifyOpInPlace(op, [&]() {
2148  op.getSourceMutable().assign(extractStridedMetadata.getSource());
2149  });
2150  return success();
2151  }
2152 
2153  // At this point, we know that the back and forth between extract strided
2154  // metadata and reinterpret cast is a noop. However, the final type of the
2155  // reinterpret cast may not be exactly the same as the original memref.
2156  // E.g., it could be changing a dimension from static to dynamic. Check that
2157  // here and add a cast if necessary.
2158  Type srcTy = extractStridedMetadata.getSource().getType();
2159  if (srcTy == op.getResult().getType())
2160  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2161  else
2162  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2163  extractStridedMetadata.getSource());
2164 
2165  return success();
2166  }
2167 };
2168 
2169 struct ReinterpretCastOpConstantFolder
2170  : public OpRewritePattern<ReinterpretCastOp> {
2171 public:
2172  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2173 
2174  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2175  PatternRewriter &rewriter) const override {
2176  unsigned srcStaticCount = llvm::count_if(
2177  llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
2178  op.getMixedStrides()),
2179  [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
2180 
2181  SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
2182  SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
2183  SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
2184 
2185  // TODO: Using counting comparison instead of direct comparison because
2186  // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
2187  // IntegerAttrs, while constifyIndexValues (and therefore
2188  // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
2189  if (srcStaticCount ==
2190  llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
2191  [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
2192  return failure();
2193 
2194  auto newReinterpretCast = ReinterpretCastOp::create(
2195  rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
2196 
2197  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
2198  return success();
2199  }
2200 };
2201 } // namespace
2202 
2203 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2204  MLIRContext *context) {
2205  results.add<ReinterpretCastOpExtractStridedMetadataFolder,
2206  ReinterpretCastOpConstantFolder>(context);
2207 }
2208 
2209 FailureOr<std::optional<SmallVector<Value>>>
2210 ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
2211  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2212 }
2213 
2214 //===----------------------------------------------------------------------===//
2215 // Reassociative reshape ops
2216 //===----------------------------------------------------------------------===//
2217 
2218 void CollapseShapeOp::getAsmResultNames(
2219  function_ref<void(Value, StringRef)> setNameFn) {
2220  setNameFn(getResult(), "collapse_shape");
2221 }
2222 
2223 void ExpandShapeOp::getAsmResultNames(
2224  function_ref<void(Value, StringRef)> setNameFn) {
2225  setNameFn(getResult(), "expand_shape");
2226 }
2227 
2228 LogicalResult ExpandShapeOp::reifyResultShapes(
2229  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2230  reifiedResultShapes = {
2231  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2232  return success();
2233 }
2234 
2235 /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
2236 /// result and operand. Layout maps are verified separately.
2237 ///
2238 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2239 /// allowed in a reassociation group.
2240 static LogicalResult
2241 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2242  ArrayRef<int64_t> expandedShape,
2243  ArrayRef<ReassociationIndices> reassociation,
2244  bool allowMultipleDynamicDimsPerGroup) {
2245  // There must be one reassociation group per collapsed dimension.
2246  if (collapsedShape.size() != reassociation.size())
2247  return op->emitOpError("invalid number of reassociation groups: found ")
2248  << reassociation.size() << ", expected " << collapsedShape.size();
2249 
2250  // The next expected expanded dimension index (while iterating over
2251  // reassociation indices).
2252  int64_t nextDim = 0;
2253  for (const auto &it : llvm::enumerate(reassociation)) {
2254  ReassociationIndices group = it.value();
2255  int64_t collapsedDim = it.index();
2256 
2257  bool foundDynamic = false;
2258  for (int64_t expandedDim : group) {
2259  if (expandedDim != nextDim++)
2260  return op->emitOpError("reassociation indices must be contiguous");
2261 
2262  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2263  return op->emitOpError("reassociation index ")
2264  << expandedDim << " is out of bounds";
2265 
2266  // Check if there are multiple dynamic dims in a reassociation group.
2267  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2268  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2269  return op->emitOpError(
2270  "at most one dimension in a reassociation group may be dynamic");
2271  foundDynamic = true;
2272  }
2273  }
2274 
2275  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2276  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2277  return op->emitOpError("collapsed dim (")
2278  << collapsedDim
2279  << ") must be dynamic if and only if reassociation group is "
2280  "dynamic";
2281 
2282  // If all dims in the reassociation group are static, the size of the
2283  // collapsed dim can be verified.
2284  if (!foundDynamic) {
2285  int64_t groupSize = 1;
2286  for (int64_t expandedDim : group)
2287  groupSize *= expandedShape[expandedDim];
2288  if (groupSize != collapsedShape[collapsedDim])
2289  return op->emitOpError("collapsed dim size (")
2290  << collapsedShape[collapsedDim]
2291  << ") must equal reassociation group size (" << groupSize << ")";
2292  }
2293  }
2294 
2295  if (collapsedShape.empty()) {
2296  // Rank 0: All expanded dimensions must be 1.
2297  for (int64_t d : expandedShape)
2298  if (d != 1)
2299  return op->emitOpError(
2300  "rank 0 memrefs can only be extended/collapsed with/from ones");
2301  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2302  // Rank >= 1: Number of dimensions among all reassociation groups must match
2303  // the result memref rank.
2304  return op->emitOpError("expanded rank (")
2305  << expandedShape.size()
2306  << ") inconsistent with number of reassociation indices (" << nextDim
2307  << ")";
2308  }
2309 
2310  return success();
2311 }
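// Editor's note: an illustrative sketch, not from the upstream sources, of the
// checks above. Collapsing memref<3x4x5xf32> with reassociation [[0, 1], [2]]
// must produce a 2-D type whose leading dimension is 3 * 4:
//
// ```
//   // OK: 12 == 3 * 4.
//   %ok = memref.collapse_shape %m [[0, 1], [2]]
//       : memref<3x4x5xf32> into memref<12x5xf32>
//
//   // Rejected with "collapsed dim size (10) must equal reassociation group
//   // size (12)".
//   %bad = memref.collapse_shape %m [[0, 1], [2]]
//       : memref<3x4x5xf32> into memref<10x5xf32>
// ```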
2312 
2313 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2314  return getSymbolLessAffineMaps(getReassociationExprs());
2315 }
2316 
2317 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2318  return convertReassociationIndicesToExprs(getContext(),
2319  getReassociationIndices());
2320 }
2321 
2322 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2323  return getSymbolLessAffineMaps(getReassociationExprs());
2324 }
2325 
2326 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2327  return convertReassociationIndicesToExprs(getContext(),
2328  getReassociationIndices());
2329 }
2330 
2331 /// Compute the layout map after expanding a given source MemRef type with the
2332 /// specified reassociation indices.
2333 static FailureOr<StridedLayoutAttr>
2334 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2335  ArrayRef<ReassociationIndices> reassociation) {
2336  int64_t srcOffset;
2337  SmallVector<int64_t> srcStrides;
2338  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2339  return failure();
2340  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2341 
2342  // 1-1 mapping between srcStrides and reassociation packs.
2343  // Each srcStride starts with the given value and gets expanded according to
2344  // the proper entries in resultShape.
2345  // Example:
2346  // srcStrides = [10000, 1 , 100 ],
2347  // reassociations = [ [0], [1], [2, 3, 4]],
2348  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2349  // -> For the purpose of stride calculation, the useful sizes are:
2350  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2351  // resultStrides = [10000, 1, 600, 200, 100]
2352  // Note that a stride does not get expanded along the first entry of each
2353  // shape pack.
2354  SmallVector<int64_t> reverseResultStrides;
2355  reverseResultStrides.reserve(resultShape.size());
2356  unsigned shapeIndex = resultShape.size() - 1;
2357  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2358  ReassociationIndices reassoc = std::get<0>(it);
2359  int64_t currentStrideToExpand = std::get<1>(it);
2360  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2361  reverseResultStrides.push_back(currentStrideToExpand);
2362  currentStrideToExpand =
2363  (SaturatedInteger::wrap(currentStrideToExpand) *
2364  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2365  .asInteger();
2366  }
2367  }
2368  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2369  resultStrides.resize(resultShape.size(), 1);
2370  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2371 }
2372 
2373 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2374  MemRefType srcType, ArrayRef<int64_t> resultShape,
2375  ArrayRef<ReassociationIndices> reassociation) {
2376  if (srcType.getLayout().isIdentity()) {
2377  // If the source is contiguous (i.e., no layout map specified), so is the
2378  // result.
2379  MemRefLayoutAttrInterface layout;
2380  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2381  srcType.getMemorySpace());
2382  }
2383 
2384  // Source may not be contiguous. Compute the layout map.
2385  FailureOr<StridedLayoutAttr> computedLayout =
2386  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2387  if (failed(computedLayout))
2388  return failure();
2389  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2390  srcType.getMemorySpace());
2391 }
2392 
2393 FailureOr<SmallVector<OpFoldResult>>
2394 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2395  MemRefType expandedType,
2396  ArrayRef<ReassociationIndices> reassociation,
2397  ArrayRef<OpFoldResult> inputShape) {
2398  std::optional<SmallVector<OpFoldResult>> outputShape =
2399  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2400  inputShape);
2401  if (!outputShape)
2402  return failure();
2403  return *outputShape;
2404 }
2405 
2406 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2407  Type resultType, Value src,
2408  ArrayRef<ReassociationIndices> reassociation,
2409  ArrayRef<OpFoldResult> outputShape) {
2410  auto [staticOutputShape, dynamicOutputShape] =
2411  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2412  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2413  getReassociationIndicesAttribute(builder, reassociation),
2414  dynamicOutputShape, staticOutputShape);
2415 }
2416 
2417 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2418  Type resultType, Value src,
2419  ArrayRef<ReassociationIndices> reassociation) {
2420  SmallVector<OpFoldResult> inputShape =
2421  getMixedSizes(builder, result.location, src);
2422  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2423  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2424  builder, result.location, memrefResultTy, reassociation, inputShape);
2425  // Failure of this assertion usually indicates presence of multiple
2426  // dynamic dimensions in the same reassociation group.
2427  assert(succeeded(outputShape) && "unable to infer output shape");
2428  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2429 }
2430 
2431 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2432  ArrayRef<int64_t> resultShape, Value src,
2433  ArrayRef<ReassociationIndices> reassociation) {
2434  // Only ranked memref source values are supported.
2435  auto srcType = llvm::cast<MemRefType>(src.getType());
2436  FailureOr<MemRefType> resultType =
2437  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2438  // Failure of this assertion usually indicates a problem with the source
2439  // type, e.g., could not get strides/offset.
2440  assert(succeeded(resultType) && "could not compute layout");
2441  build(builder, result, *resultType, src, reassociation);
2442 }
2443 
2444 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2445  ArrayRef<int64_t> resultShape, Value src,
2446  ArrayRef<ReassociationIndices> reassociation,
2447  ArrayRef<OpFoldResult> outputShape) {
2448  // Only ranked memref source values are supported.
2449  auto srcType = llvm::cast<MemRefType>(src.getType());
2450  FailureOr<MemRefType> resultType =
2451  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2452  // Failure of this assertion usually indicates a problem with the source
2453  // type, e.g., could not get strides/offset.
2454  assert(succeeded(resultType) && "could not compute layout");
2455  build(builder, result, *resultType, src, reassociation, outputShape);
2456 }
2457 
2458 LogicalResult ExpandShapeOp::verify() {
2459  MemRefType srcType = getSrcType();
2460  MemRefType resultType = getResultType();
2461 
2462  if (srcType.getRank() > resultType.getRank()) {
2463  auto r0 = srcType.getRank();
2464  auto r1 = resultType.getRank();
2465  return emitOpError("has source rank ")
2466  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2467  << r0 << " > " << r1 << ").";
2468  }
2469 
2470  // Verify result shape.
2471  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2472  resultType.getShape(),
2473  getReassociationIndices(),
2474  /*allowMultipleDynamicDimsPerGroup=*/true)))
2475  return failure();
2476 
2477  // Compute expected result type (including layout map).
2478  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2479  srcType, resultType.getShape(), getReassociationIndices());
2480  if (failed(expectedResultType))
2481  return emitOpError("invalid source layout map");
2482 
2483  // Check actual result type.
2484  if (*expectedResultType != resultType)
2485  return emitOpError("expected expanded type to be ")
2486  << *expectedResultType << " but found " << resultType;
2487 
2488  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2489  return emitOpError("expected number of static shape bounds to be equal to "
2490  "the output rank (")
2491  << resultType.getRank() << ") but found "
2492  << getStaticOutputShape().size() << " inputs instead";
2493 
2494  if ((int64_t)getOutputShape().size() !=
2495  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2496  return emitOpError("mismatch in dynamic dims in output_shape and "
2497  "static_output_shape: static_output_shape has ")
2498  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2499  << " dynamic dims while output_shape has " << getOutputShape().size()
2500  << " values";
2501 
2502  // Verify if provided output shapes are in agreement with output type.
2503  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2504  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2505  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2506  if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
2507  return emitOpError("invalid output shape provided at pos ") << pos;
2508  }
2509  }
2510 
2511  return success();
2512 }
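// Editor's note: a hedged example (SSA names hypothetical) of an expand_shape
// that passes the verifier above: the one dynamic output size is provided via
// `output_shape`, and the static entries agree with the result type.
//
// ```
//   %e = memref.expand_shape %m [[0, 1], [2]] output_shape [%sz0, 4, 5]
//       : memref<?x5xf32> into memref<?x4x5xf32>
// ```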
2513 
2514 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2515  MLIRContext *context) {
2516  results.add<
2517  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2518  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2519 }
2520 
2521 FailureOr<std::optional<SmallVector<Value>>>
2522 ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
2523  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2524 }
2525 
2526 /// Compute the layout map after collapsing a given source MemRef type with the
2527 /// specified reassociation indices.
2528 ///
2529 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2530 /// not possible to check this by inspecting a MemRefType in the general case.
2531 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2532 /// be valid (and thus accepted by this function) unless `strict = true`.
2533 static FailureOr<StridedLayoutAttr>
2534 computeCollapsedLayoutMap(MemRefType srcType,
2535  ArrayRef<ReassociationIndices> reassociation,
2536  bool strict = false) {
2537  int64_t srcOffset;
2538  SmallVector<int64_t> srcStrides;
2539  auto srcShape = srcType.getShape();
2540  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2541  return failure();
2542 
2543  // The result stride of a reassociation group is the stride of the last entry
2544  // of the reassociation. (TODO: Should be the minimum stride in the
2545  // reassociation because strides are not necessarily sorted. E.g., when using
2546  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2547  // strides are meaningless and could have any arbitrary value.
2548  SmallVector<int64_t> resultStrides;
2549  resultStrides.reserve(reassociation.size());
2550  for (const ReassociationIndices &reassoc : reassociation) {
2551  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2552  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2553  ref = ref.drop_back();
2554  if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2555  resultStrides.push_back(srcStrides[ref.back()]);
2556  } else {
2557  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2558  // the corresponding stride may have to be skipped. (See above comment.)
2559  // Therefore, the result stride cannot be statically determined and must
2560  // be dynamic.
2561  resultStrides.push_back(ShapedType::kDynamic);
2562  }
2563  }
2564 
2565  // Validate that each reassociation group is contiguous.
2566  unsigned resultStrideIndex = resultStrides.size() - 1;
2567  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2568  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2569  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2570  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2571  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2572 
2573  // Both source and result stride must have the same static value. In that
2574  // case, we can be sure, that the dimensions are collapsible (because they
2575  // are contiguous).
2576  // If `strict = false` (default during op verification), we accept cases
2577  // where one or both strides are dynamic. This is best effort: We reject
2578  // ops where obviously non-contiguous dims are collapsed, but accept ops
2579  // where we cannot be sure statically. Such ops may fail at runtime. See
2580  // the op documentation for details.
2581  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2582  if (strict && (stride.saturated || srcStride.saturated))
2583  return failure();
2584 
2585  // Dimensions of size 1 should be skipped, because their strides are
2586  // meaningless and could have any arbitrary value.
2587  if (srcShape[idx - 1] == 1)
2588  continue;
2589 
2590  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2591  return failure();
2592  }
2593  }
2594  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2595 }
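// Editor's note: an illustrative sketch, not from the upstream sources, of the
// contiguity check above on fully static sources:
//   - memref<2x4xf32, strided<[4, 1]>> collapsed with [[0, 1]] is accepted:
//     the group stride is 1, and 1 * 4 == 4 matches the outer stride, so the
//     result layout is strided<[1]>.
//   - memref<2x4xf32, strided<[8, 1]>> (e.g. a subview with a gap between
//     rows) is rejected because 1 * 4 == 4 != 8, i.e. the two dims are not
//     contiguous in memory.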
2596 
2597 bool CollapseShapeOp::isGuaranteedCollapsible(
2598  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2599  // MemRefs with identity layout are always collapsible.
2600  if (srcType.getLayout().isIdentity())
2601  return true;
2602 
2603  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2604  /*strict=*/true));
2605 }
2606 
2607 MemRefType CollapseShapeOp::computeCollapsedType(
2608  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2609  SmallVector<int64_t> resultShape;
2610  resultShape.reserve(reassociation.size());
2611  for (const ReassociationIndices &group : reassociation) {
2612  auto groupSize = SaturatedInteger::wrap(1);
2613  for (int64_t srcDim : group)
2614  groupSize =
2615  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2616  resultShape.push_back(groupSize.asInteger());
2617  }
2618 
2619  if (srcType.getLayout().isIdentity()) {
2620  // If the source is contiguous (i.e., no layout map specified), so is the
2621  // result.
2622  MemRefLayoutAttrInterface layout;
2623  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2624  srcType.getMemorySpace());
2625  }
2626 
2627  // Source may not be fully contiguous. Compute the layout map.
2628  // Note: Dimensions that are collapsed into a single dim are assumed to be
2629  // contiguous.
2630  FailureOr<StridedLayoutAttr> computedLayout =
2631  computeCollapsedLayoutMap(srcType, reassociation);
2632  assert(succeeded(computedLayout) &&
2633  "invalid source layout map or collapsing non-contiguous dims");
2634  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2635  srcType.getMemorySpace());
2636 }
2637 
2638 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2639  ArrayRef<ReassociationIndices> reassociation,
2640  ArrayRef<NamedAttribute> attrs) {
2641  auto srcType = llvm::cast<MemRefType>(src.getType());
2642  MemRefType resultType =
2643  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2644  result.addAttribute(getReassociationAttrStrName(),
2645  getReassociationIndicesAttribute(b, reassociation));
2646  build(b, result, resultType, src, attrs);
2647 }
2648 
2649 LogicalResult CollapseShapeOp::verify() {
2650  MemRefType srcType = getSrcType();
2651  MemRefType resultType = getResultType();
2652 
2653  if (srcType.getRank() < resultType.getRank()) {
2654  auto r0 = srcType.getRank();
2655  auto r1 = resultType.getRank();
2656  return emitOpError("has source rank ")
2657  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2658  << r0 << " < " << r1 << ").";
2659  }
2660 
2661  // Verify result shape.
2662  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2663  srcType.getShape(), getReassociationIndices(),
2664  /*allowMultipleDynamicDimsPerGroup=*/true)))
2665  return failure();
2666 
2667  // Compute expected result type (including layout map).
2668  MemRefType expectedResultType;
2669  if (srcType.getLayout().isIdentity()) {
2670  // If the source is contiguous (i.e., no layout map specified), so is the
2671  // result.
2672  MemRefLayoutAttrInterface layout;
2673  expectedResultType =
2674  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2675  srcType.getMemorySpace());
2676  } else {
2677  // Source may not be fully contiguous. Compute the layout map.
2678  // Note: Dimensions that are collapsed into a single dim are assumed to be
2679  // contiguous.
2680  FailureOr<StridedLayoutAttr> computedLayout =
2681  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2682  if (failed(computedLayout))
2683  return emitOpError(
2684  "invalid source layout map or collapsing non-contiguous dims");
2685  expectedResultType =
2686  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2687  *computedLayout, srcType.getMemorySpace());
2688  }
2689 
2690  if (expectedResultType != resultType)
2691  return emitOpError("expected collapsed type to be ")
2692  << expectedResultType << " but found " << resultType;
2693 
2694  return success();
2695 }
2696 
2697 struct CollapseShapeOpMemRefCastFolder
2698  : public OpRewritePattern<CollapseShapeOp> {
2699 public:
2700  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2701 
2702  LogicalResult matchAndRewrite(CollapseShapeOp op,
2703  PatternRewriter &rewriter) const override {
2704  auto cast = op.getOperand().getDefiningOp<CastOp>();
2705  if (!cast)
2706  return failure();
2707 
2708  if (!CastOp::canFoldIntoConsumerOp(cast))
2709  return failure();
2710 
2711  Type newResultType = CollapseShapeOp::computeCollapsedType(
2712  llvm::cast<MemRefType>(cast.getOperand().getType()),
2713  op.getReassociationIndices());
2714 
2715  if (newResultType == op.getResultType()) {
2716  rewriter.modifyOpInPlace(
2717  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2718  } else {
2719  Value newOp =
2720  CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
2721  op.getReassociationIndices());
2722  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2723  }
2724  return success();
2725  }
2726 };
2727 
2728 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2729  MLIRContext *context) {
2730  results.add<
2731  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2732  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2733  memref::DimOp, MemRefType>,
2734  CollapseShapeOpMemRefCastFolder>(context);
2735 }
2736 
2737 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2738  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2739  adaptor.getOperands());
2740 }
2741 
2742 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2743  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2744  adaptor.getOperands());
2745 }
2746 
2747 FailureOr<std::optional<SmallVector<Value>>>
2748 CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
2749  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2750 }
2751 
2752 //===----------------------------------------------------------------------===//
2753 // ReshapeOp
2754 //===----------------------------------------------------------------------===//
2755 
2756 void ReshapeOp::getAsmResultNames(
2757  function_ref<void(Value, StringRef)> setNameFn) {
2758  setNameFn(getResult(), "reshape");
2759 }
2760 
2761 LogicalResult ReshapeOp::verify() {
2762  Type operandType = getSource().getType();
2763  Type resultType = getResult().getType();
2764 
2765  Type operandElementType =
2766  llvm::cast<ShapedType>(operandType).getElementType();
2767  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2768  if (operandElementType != resultElementType)
2769  return emitOpError("element types of source and destination memref "
2770  "types should be the same");
2771 
2772  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2773  if (!operandMemRefType.getLayout().isIdentity())
2774  return emitOpError("source memref type should have identity affine map");
2775 
2776  int64_t shapeSize =
2777  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2778  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2779  if (resultMemRefType) {
2780  if (!resultMemRefType.getLayout().isIdentity())
2781  return emitOpError("result memref type should have identity affine map");
2782  if (shapeSize == ShapedType::kDynamic)
2783  return emitOpError("cannot use shape operand with dynamic length to "
2784  "reshape to statically-ranked memref type");
2785  if (shapeSize != resultMemRefType.getRank())
2786  return emitOpError(
2787  "length of shape operand differs from the result's memref rank");
2788  }
2789  return success();
2790 }
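// Editor's note: a hedged example, adapted from the op documentation, of a
// memref.reshape that satisfies the verifier above: identity layouts on both
// sides, matching element types, and a statically sized shape operand whose
// length equals the result rank.
//
// ```
//   %dst = memref.reshape %src(%shape)
//       : (memref<4x5xf32>, memref<2xindex>) -> memref<?x?xf32>
// ```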
2791 
2792 FailureOr<std::optional<SmallVector<Value>>>
2793 ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
2794  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2795 }
2796 
2797 //===----------------------------------------------------------------------===//
2798 // StoreOp
2799 //===----------------------------------------------------------------------===//
2800 
2801 LogicalResult StoreOp::verify() {
2802  if (getNumOperands() != 2 + getMemRefType().getRank())
2803  return emitOpError("store index operand count not equal to memref rank");
2804 
2805  return success();
2806 }
2807 
2808 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2809  SmallVectorImpl<OpFoldResult> &results) {
2810  /// store(memrefcast) -> store
2811  return foldMemRefCast(*this, getValueToStore());
2812 }
2813 
2814 FailureOr<std::optional<SmallVector<Value>>>
2815 StoreOp::bubbleDownCasts(OpBuilder &builder) {
2816  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
2817  ValueRange());
2818 }
2819 
2820 //===----------------------------------------------------------------------===//
2821 // SubViewOp
2822 //===----------------------------------------------------------------------===//
2823 
2824 void SubViewOp::getAsmResultNames(
2825  function_ref<void(Value, StringRef)> setNameFn) {
2826  setNameFn(getResult(), "subview");
2827 }
2828 
2829 /// A subview result type can be fully inferred from the source type and the
2830 /// static representation of offsets, sizes and strides. Special sentinels
2831 /// encode the dynamic case.
2832 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2833  ArrayRef<int64_t> staticOffsets,
2834  ArrayRef<int64_t> staticSizes,
2835  ArrayRef<int64_t> staticStrides) {
2836  unsigned rank = sourceMemRefType.getRank();
2837  (void)rank;
2838  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2839  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2840  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2841 
2842  // Extract source offset and strides.
2843  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2844 
2845  // Compute target offset whose value is:
2846  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2847  int64_t targetOffset = sourceOffset;
2848  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2849  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2850  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2851  SaturatedInteger::wrap(staticOffset) *
2852  SaturatedInteger::wrap(sourceStride))
2853  .asInteger();
2854  }
2855 
2856  // Compute target stride whose value is:
2857  // `sourceStrides_i * staticStrides_i`.
2858  SmallVector<int64_t, 4> targetStrides;
2859  targetStrides.reserve(staticOffsets.size());
2860  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2861  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2862  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2863  SaturatedInteger::wrap(staticStride))
2864  .asInteger());
2865  }
2866 
2867  // The type is now known.
2868  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2869  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2870  targetOffset, targetStrides),
2871  sourceMemRefType.getMemorySpace());
2872 }
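// Editor's note: a worked example, not from the upstream sources, of the
// arithmetic above. For a source memref<8x16xf32> (identity layout, i.e.
// strides [16, 1] and offset 0) with staticOffsets = [2, 4],
// staticSizes = [4, 4], and staticStrides = [1, 2]:
//   targetOffset  = 0 + 2 * 16 + 4 * 1 = 36
//   targetStrides = [16 * 1, 1 * 2] = [16, 2]
// so the inferred type is memref<4x4xf32, strided<[16, 2], offset: 36>>.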
2873 
2874 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2875  ArrayRef<OpFoldResult> offsets,
2876  ArrayRef<OpFoldResult> sizes,
2877  ArrayRef<OpFoldResult> strides) {
2878  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2879  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2880  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2881  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2882  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2883  if (!hasValidSizesOffsets(staticOffsets))
2884  return {};
2885  if (!hasValidSizesOffsets(staticSizes))
2886  return {};
2887  if (!hasValidStrides(staticStrides))
2888  return {};
2889  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2890  staticSizes, staticStrides);
2891 }
2892 
2893 MemRefType SubViewOp::inferRankReducedResultType(
2894  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2895  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2896  ArrayRef<int64_t> strides) {
2897  MemRefType inferredType =
2898  inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2899  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2900  "expected the inferred rank to be >= the result shape rank");
2901  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2902  return inferredType;
2903 
2904  // Compute which dimensions are dropped.
2905  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2906  computeRankReductionMask(inferredType.getShape(), resultShape);
2907  assert(dimsToProject.has_value() && "invalid rank reduction");
2908 
2909  // Compute the layout and result type.
2910  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2911  SmallVector<int64_t> rankReducedStrides;
2912  rankReducedStrides.reserve(resultShape.size());
2913  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2914  if (!dimsToProject->contains(idx))
2915  rankReducedStrides.push_back(value);
2916  }
2917  return MemRefType::get(resultShape, inferredType.getElementType(),
2918  StridedLayoutAttr::get(inferredLayout.getContext(),
2919  inferredLayout.getOffset(),
2920  rankReducedStrides),
2921  inferredType.getMemorySpace());
2922 }
2923 
2924 MemRefType SubViewOp::inferRankReducedResultType(
2925  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2926  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2927  ArrayRef<OpFoldResult> strides) {
2928  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2929  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2930  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2931  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2932  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2933  return SubViewOp::inferRankReducedResultType(
2934  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2935  staticStrides);
2936 }
2937 
2938 // Build a SubViewOp with mixed static and dynamic entries and custom result
2939 // type. If the type passed is nullptr, it is inferred.
2940 void SubViewOp::build(OpBuilder &b, OperationState &result,
2941  MemRefType resultType, Value source,
2942  ArrayRef<OpFoldResult> offsets,
2943  ArrayRef<OpFoldResult> sizes,
2944  ArrayRef<OpFoldResult> strides,
2945  ArrayRef<NamedAttribute> attrs) {
2946  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2947  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2948  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2949  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2950  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2951  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2952  // Structuring the implementation this way avoids duplication between builders.
2953  if (!resultType) {
2954  resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2955  staticSizes, staticStrides);
2956  }
2957  result.addAttributes(attrs);
2958  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2959  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2960  b.getDenseI64ArrayAttr(staticSizes),
2961  b.getDenseI64ArrayAttr(staticStrides));
2962 }
2963 
2964 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2965 // type.
2966 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2967  ArrayRef<OpFoldResult> offsets,
2968  ArrayRef<OpFoldResult> sizes,
2969  ArrayRef<OpFoldResult> strides,
2970  ArrayRef<NamedAttribute> attrs) {
2971  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2972 }
2973 
2974 // Build a SubViewOp with static entries and inferred result type.
2975 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2976  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2977  ArrayRef<int64_t> strides,
2978  ArrayRef<NamedAttribute> attrs) {
2979  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2980  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2981  return b.getI64IntegerAttr(v);
2982  }));
2983  SmallVector<OpFoldResult> sizeValues =
2984  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2985  return b.getI64IntegerAttr(v);
2986  }));
2987  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2988  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2989  return b.getI64IntegerAttr(v);
2990  }));
2991  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2992 }
2993 
2994 // Build a SubViewOp with static entries and custom result type. If the
2995 // type passed is nullptr, it is inferred.
2996 void SubViewOp::build(OpBuilder &b, OperationState &result,
2997  MemRefType resultType, Value source,
2998  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2999  ArrayRef<int64_t> strides,
3000  ArrayRef<NamedAttribute> attrs) {
3001  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3002  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
3003  return b.getI64IntegerAttr(v);
3004  }));
3005  SmallVector<OpFoldResult> sizeValues =
3006  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
3007  return b.getI64IntegerAttr(v);
3008  }));
3009  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3010  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
3011  return b.getI64IntegerAttr(v);
3012  }));
3013  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
3014  attrs);
3015 }
3016 
3017 // Build a SubViewOp with dynamic entries and custom result type. If the type
3018 // passed is nullptr, it is inferred.
3019 void SubViewOp::build(OpBuilder &b, OperationState &result,
3020  MemRefType resultType, Value source, ValueRange offsets,
3021  ValueRange sizes, ValueRange strides,
3022  ArrayRef<NamedAttribute> attrs) {
3023  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3024  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3025  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3026  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3027  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3028  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3029  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
3030 }
3031 
3032 // Build a SubViewOp with dynamic entries and inferred result type.
3033 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3034  ValueRange offsets, ValueRange sizes, ValueRange strides,
3035  ArrayRef<NamedAttribute> attrs) {
3036  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
3037 }
3038 
3039 /// For ViewLikeOpInterface.
3040 Value SubViewOp::getViewSource() { return getSource(); }
3041 
3042 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
3043 /// static value).
3044 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
3045  int64_t t1Offset, t2Offset;
3046  SmallVector<int64_t> t1Strides, t2Strides;
3047  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3048  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3049  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
3050 }
3051 
3052 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
3053 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
3054 /// marked as dropped in `droppedDims`.
3055 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
3056  const llvm::SmallBitVector &droppedDims) {
3057  assert(size_t(t1.getRank()) == droppedDims.size() &&
3058  "incorrect number of bits");
3059  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
3060  "incorrect number of dropped dims");
3061  int64_t t1Offset, t2Offset;
3062  SmallVector<int64_t> t1Strides, t2Strides;
3063  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3064  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3065  if (failed(res1) || failed(res2))
3066  return false;
3067  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
3068  if (droppedDims[i])
3069  continue;
3070  if (t1Strides[i] != t2Strides[j])
3071  return false;
3072  ++j;
3073  }
3074  return true;
3075 }
3076 
3077 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
3078  SubViewOp op, Type expectedType) {
3079  auto memrefType = llvm::cast<ShapedType>(expectedType);
3080  switch (result) {
3081  case SliceVerificationResult::Success:
3082  return success();
3083  case SliceVerificationResult::RankTooLarge:
3084  return op->emitError("expected result rank to be smaller or equal to ")
3085  << "the source rank, but got " << op.getType();
3086  case SliceVerificationResult::SizeMismatch:
3087  return op->emitError("expected result type to be ")
3088  << expectedType
3089  << " or a rank-reduced version. (mismatch of result sizes), but got "
3090  << op.getType();
3091  case SliceVerificationResult::ElemTypeMismatch:
3092  return op->emitError("expected result element type to be ")
3093  << memrefType.getElementType() << ", but got " << op.getType();
3094  case SliceVerificationResult::MemSpaceMismatch:
3095  return op->emitError(
3096  "expected result and source memory spaces to match, but got ")
3097  << op.getType();
3098  case SliceVerificationResult::LayoutMismatch:
3099  return op->emitError("expected result type to be ")
3100  << expectedType
3101  << " or a rank-reduced version. (mismatch of result layout), but "
3102  "got "
3103  << op.getType();
3104  }
3105  llvm_unreachable("unexpected subview verification result");
3106 }
3107 
3108 /// Verifier for SubViewOp.
3109 LogicalResult SubViewOp::verify() {
3110  MemRefType baseType = getSourceType();
3111  MemRefType subViewType = getType();
3112  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
3113  ArrayRef<int64_t> staticSizes = getStaticSizes();
3114  ArrayRef<int64_t> staticStrides = getStaticStrides();
3115 
3116  // The base memref and the view memref should be in the same memory space.
3117  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
3118  return emitError("different memory spaces specified for base memref "
3119  "type ")
3120  << baseType << " and subview memref type " << subViewType;
3121 
3122  // Verify that the base memref type has a strided layout map.
3123  if (!baseType.isStrided())
3124  return emitError("base type ") << baseType << " is not strided";
3125 
3126  // Compute the expected result type, assuming that there are no rank
3127  // reductions.
3128  MemRefType expectedType = SubViewOp::inferResultType(
3129  baseType, staticOffsets, staticSizes, staticStrides);
3130 
3131  // Verify all properties of a shaped type: rank, element type and dimension
3132  // sizes. This takes into account potential rank reductions.
3133  auto shapedTypeVerification = isRankReducedType(
3134  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
3135  if (shapedTypeVerification != SliceVerificationResult::Success)
3136  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
3137 
3138  // Make sure that the memory space did not change.
3139  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3140  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
3141  *this, expectedType);
3142 
3143  // Verify the offset of the layout map.
3144  if (!haveCompatibleOffsets(expectedType, subViewType))
3145  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3146  *this, expectedType);
3147 
3148  // The only things left to verify now are the strides. First, compute the
3149  // unused dimensions due to rank reductions. We have to look at sizes and
3150  // strides to decide which dimensions were dropped. This function also
3151  // partially verifies strides in case of rank reductions.
3152  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3153  getMixedSizes());
3154  if (failed(unusedDims))
3155  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3156  *this, expectedType);
3157 
3158  // Strides must match.
3159  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3160  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3161  *this, expectedType);
3162 
3163  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3164  // to the base memref.
3165  SliceBoundsVerificationResult boundsResult =
3166  verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
3167  staticStrides, /*generateErrorMessage=*/true);
3168  if (!boundsResult.isValid)
3169  return getOperation()->emitError(boundsResult.errorMessage);
3170 
3171  return success();
3172 }
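// For illustration (a hand-written example assumed for exposition rather than
// taken from a test), the following rank-reducing subview passes this
// verifier:
// ```
// %0 = memref.subview %src[0, 0, 0] [1, 16, 4] [1, 1, 1]
//     : memref<8x16x4xf32> to memref<16x4xf32, strided<[4, 1]>>
// ```
// The expected (non-rank-reduced) type is
// `memref<1x16x4xf32, strided<[64, 4, 1]>>`; the unit-sized leading dimension
// is dropped, and the remaining strides and the offset match.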
3173 
3174 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3175  return os << "range " << range.offset << ":" << range.size << ":"
3176  << range.stride;
3177 }
3178 
3179 /// Return the list of Range (i.e. offset, size, stride). Each Range
3180 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3181 /// with `b` at location `loc`.
3182 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3183  OpBuilder &b, Location loc) {
3184  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3185  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3186  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3187  SmallVector<Range, 8> res;
3188  unsigned rank = ranks[0];
3189  res.reserve(rank);
3190  for (unsigned idx = 0; idx < rank; ++idx) {
3191  Value offset =
3192  op.isDynamicOffset(idx)
3193  ? op.getDynamicOffset(idx)
3194  : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
3195  Value size =
3196  op.isDynamicSize(idx)
3197  ? op.getDynamicSize(idx)
3198  : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
3199  Value stride =
3200  op.isDynamicStride(idx)
3201  ? op.getDynamicStride(idx)
3202  : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
3203  res.emplace_back(Range{offset, size, stride});
3204  }
3205  return res;
3206 }
3207 
3208 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3209 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3210 /// the rank of the inferred result type if `currentResultType` is lower rank
3211 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3212 /// together with the result type. In this case, it is important to compute
3213 /// the dropped dimensions using `currentSourceType` whose strides align with
3214 /// `currentResultType`.
3215 static MemRefType getCanonicalSubViewResultType(
3216  MemRefType currentResultType, MemRefType currentSourceType,
3217  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3218  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3219  MemRefType nonRankReducedType = SubViewOp::inferResultType(
3220  sourceType, mixedOffsets, mixedSizes, mixedStrides);
3221  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3222  currentSourceType, currentResultType, mixedSizes);
3223  if (failed(unusedDims))
3224  return nullptr;
3225 
3226  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3227  SmallVector<int64_t> shape, strides;
3228  unsigned numDimsAfterReduction =
3229  nonRankReducedType.getRank() - unusedDims->count();
3230  shape.reserve(numDimsAfterReduction);
3231  strides.reserve(numDimsAfterReduction);
3232  for (const auto &[idx, size, stride] :
3233  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3234  nonRankReducedType.getShape(), layout.getStrides())) {
3235  if (unusedDims->test(idx))
3236  continue;
3237  shape.push_back(size);
3238  strides.push_back(stride);
3239  }
3240 
3241  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3242  StridedLayoutAttr::get(sourceType.getContext(),
3243  layout.getOffset(), strides),
3244  nonRankReducedType.getMemorySpace());
3245 }
3246 
3247 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3248  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3249  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3250  unsigned rank = memrefType.getRank();
3251  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3252  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3253  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3254  MemRefType targetType = SubViewOp::inferRankReducedResultType(
3255  targetShape, memrefType, offsets, sizes, strides);
3256  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3257  sizes, strides);
3258 }
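// A minimal sketch of the canonical rank-reducing subview this helper builds
// (source type and target shape assumed for exposition):
// ```
// %0 = memref.subview %arg[0, 0] [1, 16] [1, 1]
//     : memref<1x16xf32> to memref<16xf32, strided<[1]>>
// ```
// Offsets are all 0, strides are all 1, the sizes are the source sizes, and
// the result type is the rank-reduced type inferred for `targetShape`.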
3259 
3260 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3261  Value value,
3262  ArrayRef<int64_t> desiredShape) {
3263  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3264  assert(sourceMemrefType && "not a ranked memref type");
3265  auto sourceShape = sourceMemrefType.getShape();
3266  if (sourceShape.equals(desiredShape))
3267  return value;
3268  auto maybeRankReductionMask =
3269  mlir::computeRankReductionMask(sourceShape, desiredShape);
3270  if (!maybeRankReductionMask)
3271  return failure();
3272  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3273 }
3274 
3275 /// Helper method to check if a `subview` operation is trivially a no-op. This
3276 /// is the case if all offsets are zero, all strides are 1, and the source
3277 /// shape is the same as the size of the subview. In such cases, the subview
3278 /// can be folded into its source.
3279 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3280  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3281  return false;
3282 
3283  auto mixedOffsets = subViewOp.getMixedOffsets();
3284  auto mixedSizes = subViewOp.getMixedSizes();
3285  auto mixedStrides = subViewOp.getMixedStrides();
3286 
3287  // Check offsets are zero.
3288  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3289  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3290  return !intValue || intValue.value() != 0;
3291  }))
3292  return false;
3293 
3294  // Check strides are one.
3295  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3296  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3297  return !intValue || intValue.value() != 1;
3298  }))
3299  return false;
3300 
3301  // Check that all size values are static and match the (static) source shape.
3302  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3303  for (const auto &size : llvm::enumerate(mixedSizes)) {
3304  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3305  if (!intValue || *intValue != sourceShape[size.index()])
3306  return false;
3307  }
3308  // All conditions met. The `SubViewOp` is foldable as a no-op.
3309  return true;
3310 }
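// Example of a trivially foldable subview in the sense above (assumed for
// exposition): every offset is 0, every stride is 1, and the sizes equal the
// static source shape.
// ```
// %0 = memref.subview %arg[0, 0] [4, 4] [1, 1]
//     : memref<4x4xf32> to memref<4x4xf32>
// ```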
3311 
3312 namespace {
3313 /// Pattern to rewrite a subview op with MemRefCast arguments.
3314 /// This essentially pushes memref.cast past its consuming subview when
3315 /// `canFoldIntoConsumerOp` is true.
3316 ///
3317 /// Example:
3318 /// ```
3319 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3320 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3321 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3322 /// ```
3323 /// is rewritten into:
3324 /// ```
3325 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
3326 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3327 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3328 /// ```
3329 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3330 public:
3331  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3332 
3333  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3334  PatternRewriter &rewriter) const override {
3335  // If any operand is constant, just return to let SubViewOpConstantFolder
3336  // kick in.
3337  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3338  return matchPattern(operand, matchConstantIndex());
3339  }))
3340  return failure();
3341 
3342  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3343  if (!castOp)
3344  return failure();
3345 
3346  if (!CastOp::canFoldIntoConsumerOp(castOp))
3347  return failure();
3348 
3349  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3350  // the MemRefCastOp source operand type to infer the result type and the
3351  // current SubViewOp source operand type to compute the dropped dimensions
3352  // if the operation is rank-reducing.
3353  auto resultType = getCanonicalSubViewResultType(
3354  subViewOp.getType(), subViewOp.getSourceType(),
3355  llvm::cast<MemRefType>(castOp.getSource().getType()),
3356  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3357  subViewOp.getMixedStrides());
3358  if (!resultType)
3359  return failure();
3360 
3361  Value newSubView = SubViewOp::create(
3362  rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
3363  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3364  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3365  subViewOp.getStaticStrides());
3366  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3367  newSubView);
3368  return success();
3369  }
3370 };
3371 
3372 /// Canonicalize subview ops that are no-ops. When the source shape is not the
3373 /// same as the result shape (e.g. due to use of `affine_map`), a cast is used.
3374 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3375 public:
3376  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3377 
3378  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3379  PatternRewriter &rewriter) const override {
3380  if (!isTrivialSubViewOp(subViewOp))
3381  return failure();
3382  if (subViewOp.getSourceType() == subViewOp.getType()) {
3383  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3384  return success();
3385  }
3386  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3387  subViewOp.getSource());
3388  return success();
3389  }
3390 };
3391 } // namespace
3392 
3393 /// Return the canonical type of the result of a subview.
3394 struct SubViewReturnTypeCanonicalizer {
3395  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3396  ArrayRef<OpFoldResult> mixedSizes,
3397  ArrayRef<OpFoldResult> mixedStrides) {
3398  // Infer a memref type without taking into account any rank reductions.
3399  MemRefType resTy = SubViewOp::inferResultType(
3400  op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3401  if (!resTy)
3402  return {};
3403  MemRefType nonReducedType = resTy;
3404 
3405  // Directly return the non-rank reduced type if there are no dropped dims.
3406  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3407  if (droppedDims.none())
3408  return nonReducedType;
3409 
3410  // Take the strides and offset from the non-rank reduced type.
3411  auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3412 
3413  // Drop dims from shape and strides.
3414  SmallVector<int64_t> targetShape;
3415  SmallVector<int64_t> targetStrides;
3416  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3417  if (droppedDims.test(i))
3418  continue;
3419  targetStrides.push_back(nonReducedStrides[i]);
3420  targetShape.push_back(nonReducedType.getDimSize(i));
3421  }
3422 
3423  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3424  StridedLayoutAttr::get(nonReducedType.getContext(),
3425  offset, targetStrides),
3426  nonReducedType.getMemorySpace());
3427  }
3428 };
3429 
3430 /// A canonicalizer wrapper to replace SubViewOps.
3431 struct SubViewCanonicalizer {
3432  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3433  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3434  }
3435 };
3436 
3437 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3438  MLIRContext *context) {
3439  results
3440  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3441  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3442  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3443 }
3444 
3445 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3446  MemRefType sourceMemrefType = getSource().getType();
3447  MemRefType resultMemrefType = getResult().getType();
3448  auto resultLayout =
3449  dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3450 
3451  if (resultMemrefType == sourceMemrefType &&
3452  resultMemrefType.hasStaticShape() &&
3453  (!resultLayout || resultLayout.hasStaticLayout())) {
3454  return getViewSource();
3455  }
3456 
3457  // Fold subview(subview(x)), where both subviews have the same size and the
3458  // second subview's offsets are all zero. (I.e., the second subview is a
3459  // no-op.)
3460  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3461  auto srcSizes = srcSubview.getMixedSizes();
3462  auto sizes = getMixedSizes();
3463  auto offsets = getMixedOffsets();
3464  bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3465  auto strides = getMixedStrides();
3466  bool allStridesOne = llvm::all_of(strides, isOneInteger);
3467  bool allSizesSame = llvm::equal(sizes, srcSizes);
3468  if (allOffsetsZero && allStridesOne && allSizesSame &&
3469  resultMemrefType == sourceMemrefType)
3470  return getViewSource();
3471  }
3472 
3473  return {};
3474 }
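// A sketch of the subview-of-subview folding above (types and values assumed
// for exposition): the second subview below has zero offsets, unit strides,
// the same sizes as its source subview, and an unchanged type, so it folds to
// `%1`.
// ```
// %1 = memref.subview %0[2, 2] [2, 2] [1, 1]
//     : memref<4x4xf32> to memref<2x2xf32, strided<[4, 1], offset: 10>>
// %2 = memref.subview %1[0, 0] [2, 2] [1, 1]
//     : memref<2x2xf32, strided<[4, 1], offset: 10>>
//       to memref<2x2xf32, strided<[4, 1], offset: 10>>
// ```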
3475 
3476 FailureOr<std::optional<SmallVector<Value>>>
3477 SubViewOp::bubbleDownCasts(OpBuilder &builder) {
3478  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3479 }
3480 
3481 void SubViewOp::inferStridedMetadataRanges(
3482  ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
3483  SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
3484  auto isUninitialized =
3485  +[](IntegerValueRange range) { return range.isUninitialized(); };
3486 
3487  // Bail out early if any of the operands' metadata is not yet ready:
3488  SmallVector<IntegerValueRange> offsetOperands =
3489  getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
3490  if (llvm::any_of(offsetOperands, isUninitialized))
3491  return;
3492 
3493  SmallVector<IntegerValueRange> sizeOperands =
3494  getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
3495  if (llvm::any_of(sizeOperands, isUninitialized))
3496  return;
3497 
3498  SmallVector<IntegerValueRange> stridesOperands =
3499  getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
3500  if (llvm::any_of(stridesOperands, isUninitialized))
3501  return;
3502 
3503  StridedMetadataRange sourceRange =
3504  ranges[getSourceMutable().getOperandNumber()];
3505  if (sourceRange.isUninitialized())
3506  return;
3507 
3508  ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
3509 
3510  // Get the dropped dims.
3511  llvm::SmallBitVector droppedDims = getDroppedDims();
3512 
3513  // Compute the new offset, strides and sizes.
3514  ConstantIntRanges offset = sourceRange.getOffsets()[0];
3515  SmallVector<ConstantIntRanges> strides, sizes;
3516 
3517  for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
3518  bool dropped = droppedDims.test(i);
3519  // Compute the new offset.
3520  ConstantIntRanges off =
3521  intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
3522  offset = intrange::inferAdd({offset, off});
3523 
3524  // Skip dropped dimensions.
3525  if (dropped)
3526  continue;
3527  // Multiply the strides.
3528  strides.push_back(
3529  intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
3530  // Get the sizes.
3531  sizes.push_back(sizeOperands[i].getValue());
3532  }
3533 
3534  setMetadata(getResult(),
3535  StridedMetadataRange::getRanked(
3536  SmallVector<ConstantIntRanges>({std::move(offset)}),
3537  std::move(sizes), std::move(strides)));
3538 }
3539 
3540 //===----------------------------------------------------------------------===//
3541 // TransposeOp
3542 //===----------------------------------------------------------------------===//
3543 
3544 void TransposeOp::getAsmResultNames(
3545  function_ref<void(Value, StringRef)> setNameFn) {
3546  setNameFn(getResult(), "transpose");
3547 }
3548 
3549 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3550 static MemRefType inferTransposeResultType(MemRefType memRefType,
3551  AffineMap permutationMap) {
3552  auto originalSizes = memRefType.getShape();
3553  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3554  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3555 
3556  // Compute permuted sizes and strides.
3557  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3558  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3559 
3560  return MemRefType::Builder(memRefType)
3561  .setShape(sizes)
3562  .setLayout(
3563  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3564 }
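// As an illustrative sketch of the inference above (shapes assumed for
// exposition): transposing a `memref<3x4xf32>` (strides [4, 1]) with the
// permutation `(d0, d1) -> (d1, d0)` yields
// ```
// memref<4x3xf32, strided<[1, 4]>>
// ```
// i.e. both the sizes and the strides are permuted, while the offset is kept.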
3565 
3566 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3567  AffineMapAttr permutation,
3568  ArrayRef<NamedAttribute> attrs) {
3569  auto permutationMap = permutation.getValue();
3570  assert(permutationMap);
3571 
3572  auto memRefType = llvm::cast<MemRefType>(in.getType());
3573  // Compute result type.
3574  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3575 
3576  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3577  build(b, result, resultType, in, attrs);
3578 }
3579 
3580 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3581 void TransposeOp::print(OpAsmPrinter &p) {
3582  p << " " << getIn() << " " << getPermutation();
3583  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3584  p << " : " << getIn().getType() << " to " << getType();
3585 }
3586 
3587 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3588  OpAsmParser::UnresolvedOperand in;
3589  AffineMap permutation;
3590  MemRefType srcType, dstType;
3591  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3592  parser.parseOptionalAttrDict(result.attributes) ||
3593  parser.parseColonType(srcType) ||
3594  parser.resolveOperand(in, srcType, result.operands) ||
3595  parser.parseKeywordType("to", dstType) ||
3596  parser.addTypeToList(dstType, result.types))
3597  return failure();
3598 
3599  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3600  AffineMapAttr::get(permutation));
3601  return success();
3602 }
3603 
3604 LogicalResult TransposeOp::verify() {
3605  if (!getPermutation().isPermutation())
3606  return emitOpError("expected a permutation map");
3607  if (getPermutation().getNumDims() != getIn().getType().getRank())
3608  return emitOpError("expected a permutation map of same rank as the input");
3609 
3610  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3611  auto resultType = llvm::cast<MemRefType>(getType());
3612  auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3613  .canonicalizeStridedLayout();
3614 
3615  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3616  return emitOpError("result type ")
3617  << resultType
3618  << " is not equivalent to the canonical transposed input type "
3619  << canonicalResultType;
3620  return success();
3621 }
3622 
3623 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3624  // First, check for an identity permutation; we can fold it away if the
3625  // input and result types are already identical.
3626  if (getPermutation().isIdentity() && getType() == getIn().getType())
3627  return getIn();
3628  // Fold two consecutive memref.transpose Ops into one by composing their
3629  // permutation maps.
3630  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3631  AffineMap composedPermutation =
3632  getPermutation().compose(otherTransposeOp.getPermutation());
3633  getInMutable().assign(otherTransposeOp.getIn());
3634  setPermutation(composedPermutation);
3635  return getResult();
3636  }
3637  return {};
3638 }
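// A sketch of the two-transpose fold above (types assumed for exposition):
// ```
// %1 = memref.transpose %0 (d0, d1) -> (d1, d0)
//     : memref<3x4xf32> to memref<4x3xf32, strided<[1, 4]>>
// %2 = memref.transpose %1 (d0, d1) -> (d1, d0)
//     : memref<4x3xf32, strided<[1, 4]>> to memref<3x4xf32, strided<[4, 1]>>
// ```
// Folding rewrites `%2` to read `%0` directly with the composed (here,
// identity) permutation; its result type is unchanged.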
3639 
3640 FailureOr<std::optional<SmallVector<Value>>>
3641 TransposeOp::bubbleDownCasts(OpBuilder &builder) {
3642  return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
3643 }
3644 
3645 //===----------------------------------------------------------------------===//
3646 // ViewOp
3647 //===----------------------------------------------------------------------===//
3648 
3649 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3650  setNameFn(getResult(), "view");
3651 }
3652 
3653 LogicalResult ViewOp::verify() {
3654  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3655  auto viewType = getType();
3656 
3657  // The base memref should have identity layout map (or none).
3658  if (!baseType.getLayout().isIdentity())
3659  return emitError("unsupported map for base memref type ") << baseType;
3660 
3661  // The result memref should have identity layout map (or none).
3662  if (!viewType.getLayout().isIdentity())
3663  return emitError("unsupported map for result memref type ") << viewType;
3664 
3665  // The base memref and the view memref should be in the same memory space.
3666  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3667  return emitError("different memory spaces specified for base memref "
3668  "type ")
3669  << baseType << " and view memref type " << viewType;
3670 
3671  // Verify that we have the correct number of sizes for the result type.
3672  unsigned numDynamicDims = viewType.getNumDynamicDims();
3673  if (getSizes().size() != numDynamicDims)
3674  return emitError("incorrect number of size operands for type ") << viewType;
3675 
3676  return success();
3677 }
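// Example of a view that satisfies these checks (buffer size and SSA names
// assumed for exposition): identity layouts on both sides, matching memory
// spaces, and one size operand for the single dynamic dimension.
// ```
// %0 = memref.view %buf[%byte_shift][%n] : memref<2048xi8> to memref<?x4xf32>
// ```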
3678 
3679 Value ViewOp::getViewSource() { return getSource(); }
3680 
3681 OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3682  MemRefType sourceMemrefType = getSource().getType();
3683  MemRefType resultMemrefType = getResult().getType();
3684 
3685  if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
3686  return getViewSource();
3687 
3688  return {};
3689 }
3690 
3691 namespace {
3692 
3693 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3694  using OpRewritePattern<ViewOp>::OpRewritePattern;
3695 
3696  LogicalResult matchAndRewrite(ViewOp viewOp,
3697  PatternRewriter &rewriter) const override {
3698  // Return if none of the operands are constants.
3699  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3700  return matchPattern(operand, matchConstantIndex());
3701  }))
3702  return failure();
3703 
3704  // Get result memref type.
3705  auto memrefType = viewOp.getType();
3706 
3707  // Get the offset from the old memref view type `memrefType`.
3708  int64_t oldOffset;
3709  SmallVector<int64_t, 4> oldStrides;
3710  if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3711  return failure();
3712  assert(oldOffset == 0 && "Expected 0 offset");
3713 
3714  SmallVector<Value, 4> newOperands;
3715 
3716  // Offset cannot be folded into result type.
3717 
3718  // Fold any dynamic dim operands which are produced by a constant.
3719  SmallVector<int64_t, 4> newShapeConstants;
3720  newShapeConstants.reserve(memrefType.getRank());
3721 
3722  unsigned dynamicDimPos = 0;
3723  unsigned rank = memrefType.getRank();
3724  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3725  int64_t dimSize = memrefType.getDimSize(dim);
3726  // If this is already a static dimension, keep it.
3727  if (ShapedType::isStatic(dimSize)) {
3728  newShapeConstants.push_back(dimSize);
3729  continue;
3730  }
3731  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3732  if (auto constantIndexOp =
3733  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3734  // Dynamic shape dimension will be folded.
3735  newShapeConstants.push_back(constantIndexOp.value());
3736  } else {
3737  // Dynamic shape dimension not folded; copy operand from old memref.
3738  newShapeConstants.push_back(dimSize);
3739  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3740  }
3741  dynamicDimPos++;
3742  }
3743 
3744  // Create new memref type with constant folded dims.
3745  MemRefType newMemRefType =
3746  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3747  // Nothing new, don't fold.
3748  if (newMemRefType == memrefType)
3749  return failure();
3750 
3751  // Create new ViewOp.
3752  auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
3753  viewOp.getOperand(0), viewOp.getByteShift(),
3754  newOperands);
3755  // Insert a cast so we have the same type as the old memref type.
3756  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3757  return success();
3758  }
3759 };
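// A sketch of what ViewOpShapeFolder does (constants and shapes assumed for
// exposition): a dynamic size produced by `arith.constant` is folded into the
// result type, and a cast restores the original type.
// ```
// %0 = memref.view %buf[%c0][%c16] : memref<2048xi8> to memref<?x4xf32>
// ```
// becomes
// ```
// %1 = memref.view %buf[%c0][] : memref<2048xi8> to memref<16x4xf32>
// %0 = memref.cast %1 : memref<16x4xf32> to memref<?x4xf32>
// ```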
3760 
3761 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3762  using OpRewritePattern<ViewOp>::OpRewritePattern;
3763 
3764  LogicalResult matchAndRewrite(ViewOp viewOp,
3765  PatternRewriter &rewriter) const override {
3766  Value memrefOperand = viewOp.getOperand(0);
3767  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3768  if (!memrefCastOp)
3769  return failure();
3770  Value allocOperand = memrefCastOp.getOperand();
3771  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3772  if (!allocOp)
3773  return failure();
3774  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3775  viewOp.getByteShift(),
3776  viewOp.getSizes());
3777  return success();
3778  }
3779 };
3780 
3781 } // namespace
3782 
3783 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3784  MLIRContext *context) {
3785  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3786 }
3787 
3788 FailureOr<std::optional<SmallVector<Value>>>
3789 ViewOp::bubbleDownCasts(OpBuilder &builder) {
3790  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3791 }
3792 
3793 //===----------------------------------------------------------------------===//
3794 // AtomicRMWOp
3795 //===----------------------------------------------------------------------===//
3796 
3797 LogicalResult AtomicRMWOp::verify() {
3798  if (getMemRefType().getRank() != getNumOperands() - 2)
3799  return emitOpError(
3800  "expects the number of subscripts to be equal to memref rank");
3801  switch (getKind()) {
3802  case arith::AtomicRMWKind::addf:
3803  case arith::AtomicRMWKind::maximumf:
3804  case arith::AtomicRMWKind::minimumf:
3805  case arith::AtomicRMWKind::mulf:
3806  if (!llvm::isa<FloatType>(getValue().getType()))
3807  return emitOpError() << "with kind '"
3808  << arith::stringifyAtomicRMWKind(getKind())
3809  << "' expects a floating-point type";
3810  break;
3811  case arith::AtomicRMWKind::addi:
3812  case arith::AtomicRMWKind::maxs:
3813  case arith::AtomicRMWKind::maxu:
3814  case arith::AtomicRMWKind::mins:
3815  case arith::AtomicRMWKind::minu:
3816  case arith::AtomicRMWKind::muli:
3817  case arith::AtomicRMWKind::ori:
3818  case arith::AtomicRMWKind::xori:
3819  case arith::AtomicRMWKind::andi:
3820  if (!llvm::isa<IntegerType>(getValue().getType()))
3821  return emitOpError() << "with kind '"
3822  << arith::stringifyAtomicRMWKind(getKind())
3823  << "' expects an integer type";
3824  break;
3825  default:
3826  break;
3827  }
3828  return success();
3829 }
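// Roughly, a well-formed op in the sense of this verifier looks like the
// following (an assumed example; one index per memref dimension and a
// floating-point value for the `addf` kind):
// ```
// %res = memref.atomic_rmw addf %val, %mem[%i] : (f32, memref<8xf32>) -> f32
// ```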
3830 
3831 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3832  /// atomicrmw(memrefcast) -> atomicrmw
3833  if (succeeded(foldMemRefCast(*this, getValue())))
3834  return getResult();
3835  return OpFoldResult();
3836 }
3837 
3838 FailureOr<std::optional<SmallVector<Value>>>
3839 AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
3840  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
3841  getResult());
3842 }
3843 
3844 //===----------------------------------------------------------------------===//
3845 // TableGen'd op method definitions
3846 //===----------------------------------------------------------------------===//
3847 
3848 #define GET_OP_CLASSES
3849 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
Definition: Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition: MemRefOps.cpp:60
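For illustration, a minimal usage sketch (toTensorType and the chosen memref shape are made up for this example; only an available MLIRContext is assumed):

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// memref<4x?xf32> maps to tensor<4x?xf32>; an unranked memref maps to an
// unranked tensor, and any non-memref type falls back to NoneType.
Type toTensorType(MLIRContext *ctx) {
  auto memrefTy =
      MemRefType::get({4, ShapedType::kDynamic}, Float32Type::get(ctx));
  return memref::getTensorTypeFromMemRefType(memrefTy);
}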
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the size of dimension dim of the given memref value, either as a static IndexAttr or as a memref.dim value when the dimension is dynamic.
Definition: MemRefOps.cpp:68
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
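As a hedged sketch of the usual call site: an op's fold hook forwards memref.cast operands and reports whether anything changed. MyReadOp is a hypothetical single-result op with the standard ODS-generated FoldAdaptor.

// "myop(memref.cast(%src)) -> myop(%src)" folding, written against a
// hypothetical op. foldMemRefCast mutates the operands in place and returns
// success() iff at least one ranked memref.cast was folded away.
OpFoldResult MyReadOp::fold(FoldAdaptor adaptor) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult()
                                                  : OpFoldResult();
}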
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the sizes of all dimensions of the given memref value as OpFoldResults.
Definition: MemRefOps.cpp:77
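A small usage sketch, assuming a builder, a location, and a memref-typed value are already in scope (for example inside a rewrite pattern); querySizes is an illustrative name:

#include "mlir/Dialect/MemRef/IR/MemRef.h"

using namespace mlir;

void querySizes(OpBuilder &builder, Location loc, Value memrefValue) {
  // Single dimension: a static extent comes back as an IndexAttr, a dynamic
  // one as the result of a (possibly folded) memref.dim op.
  OpFoldResult dim0 = memref::getMixedSize(builder, loc, memrefValue, 0);

  // All dimensions at once, in dimension order.
  SmallVector<OpFoldResult> sizes =
      memref::getMixedSizes(builder, loc, memrefValue);
  (void)dim0;
  (void)sizes;
}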
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 ...
Definition: MemRefOps.cpp:3247
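A hedged sketch of a typical call, assuming the source value has type memref<1x4x1x8xf32>; dropUnitDims is a made-up wrapper name:

#include "mlir/Dialect/MemRef/IR/MemRef.h"

using namespace mlir;

// Conceptually builds a memref.subview of %source with zero offsets, unit
// strides, and sizes equal to the source shape, inferring a result type
// rank-reduced toward the 2-D target shape (the unit dims are dropped).
Value dropUnitDims(OpBuilder &b, Location loc, Value source) {
  return memref::createCanonicalRankReducingSubViewOp(b, loc, source,
                                                      /*targetShape=*/{4, 8});
}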
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:318
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:356
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
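A minimal sketch of the common "is this mixed value a known constant?" check; isUnitStride is an illustrative helper name:

#include "mlir/Dialect/Utils/StaticValueUtils.h"

using namespace mlir;

bool isUnitStride(OpFoldResult stride) {
  // Looks through IndexAttr/IntegerAttr as well as Values defined by constant
  // ops; std::nullopt means the entry is genuinely dynamic.
  std::optional<int64_t> cst = getConstantIntValue(stride);
  return cst && *cst == 1;
}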
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
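A hedged example of checking a fully static slice against its source shape; the numbers are arbitrary but in bounds, and the header path is an assumption:

#include <cassert>

#include "mlir/Dialect/Utils/StaticValueUtils.h" // assumed location of the declaration

using namespace mlir;

void checkStaticSlice() {
  // Slice of a 16x32 source at offsets [0, 8], sizes [16, 16], strides [1, 1]:
  // the accessed window stays within the 16x32 source.
  SliceBoundsVerificationResult res = verifyInBoundsSlice(
      /*shape=*/{16, 32}, /*staticOffsets=*/{0, 8}, /*staticSizes=*/{16, 16},
      /*staticStrides=*/{1, 1}, /*generateErrorMessage=*/true);
  assert(res.isValid && "slice must be in bounds");
}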
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
Definition: MemRefOps.cpp:3182
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:23
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
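For illustration, a hedged sketch of the round-trip between the split static/dynamic representation and the mixed OpFoldResult form (roundTrip is a made-up name):

#include "mlir/Dialect/Utils/StaticValueUtils.h"

using namespace mlir;

void roundTrip(MLIRContext *ctx, ArrayRef<int64_t> staticSizes,
               ValueRange dynamicSizes) {
  // Merge: static entries become IndexAttrs, and each ShapedType::kDynamic
  // entry consumes the next SSA value from dynamicSizes.
  SmallVector<OpFoldResult> mixed =
      getMixedValues(staticSizes, dynamicSizes, ctx);

  // Split back apart, e.g. when rebuilding a view-like op with new operands.
  SmallVector<Value> newDynamic;
  SmallVector<int64_t> newStatic;
  dispatchIndexOpFoldResults(mixed, newDynamic, newStatic);
}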
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
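A small hedged example of the expected mask for a typical unit-dimension drop (rankReductionExample is illustrative only):

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

void rankReductionExample() {
  // Reducing shape [1, 4, 1, 8] to [4, 8]: the returned mask holds the
  // positions of the dropped unit dimensions, here {0, 2}. std::nullopt would
  // mean the reduced shape is not obtainable by dropping 1-sized dims.
  std::optional<llvm::SmallDenseSet<unsigned>> mask =
      computeRankReductionMask({1, 4, 1, 8}, {4, 8});
  (void)mask;
}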
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Move allocations into an allocation scope, if it is legal to move them (e.g.
Definition: MemRefOps.cpp:507
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:510
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
Definition: MemRefOps.cpp:467
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:470
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:2702
A canonicalizer wrapper to replace SubViewOps.
Definition: MemRefOps.cpp:3431
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Definition: MemRefOps.cpp:3432
Return the canonical type of the result of a subview.
Definition: MemRefOps.cpp:3394
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: MemRefOps.cpp:3395
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
Result of slice bounds verification.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.