1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::memref;
29 
30 /// Materialize a single constant operation from a given attribute value with
31 /// the desired resultant type.
32 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
33  Attribute value, Type type,
34  Location loc) {
35  return arith::ConstantOp::materialize(builder, value, type, loc);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Common canonicalization pattern support logic
40 //===----------------------------------------------------------------------===//
41 
42 /// This is a common class used for patterns of the form
43 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
44 /// into the root operation directly.
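/// As one illustrative instance (a sketch only; value names and types are
/// made up), the `dealloc(memref.cast) -> dealloc` fold below rewrites:
///
/// ```mlir
/// %0 = memref.cast %arg : memref<8x16xf32> to memref<?x?xf32>
/// memref.dealloc %0 : memref<?x?xf32>
/// ```
///
/// into:
///
/// ```mlir
/// memref.dealloc %arg : memref<8x16xf32>
/// ```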
45 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
46  bool folded = false;
47  for (OpOperand &operand : op->getOpOperands()) {
48  auto cast = operand.get().getDefiningOp<CastOp>();
49  if (cast && operand.get() != inner &&
50  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
51  operand.set(cast.getOperand());
52  folded = true;
53  }
54  }
55  return success(folded);
56 }
57 
58 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
59 /// type.
60 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
61  if (auto memref = llvm::dyn_cast<MemRefType>(type))
62  return RankedTensorType::get(memref.getShape(), memref.getElementType());
63  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
64  return UnrankedTensorType::get(memref.getElementType());
65  return NoneType::get(type.getContext());
66 }
67 
68 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
69  int64_t dim) {
70  auto memrefType = llvm::cast<MemRefType>(value.getType());
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78  Location loc, Value value) {
79  auto memrefType = llvm::cast<MemRefType>(value.getType());
80  SmallVector<OpFoldResult> result;
81  for (int64_t i = 0; i < memrefType.getRank(); ++i)
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result;
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that sets values[i] to constValues[i] if the latter is a
91 /// static value, i.e., not equal to ShapedType::kDynamic.
92 ///
93 /// If constValues[i] is dynamic, tries to extract a constant value from
94 /// value[i] to allow for additional folding opportunities. Also converts all
95 /// existing attributes to index attributes. (They may be i64 attributes.)
96 static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
97  ArrayRef<int64_t> constValues) {
98  assert(constValues.size() == values.size() &&
99  "incorrect number of const values");
100  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
101  Builder builder(values[i].getContext());
102  if (ShapedType::isStatic(cstVal)) {
103  // Constant value is known, use it directly.
104  values[i] = builder.getIndexAttr(cstVal);
105  continue;
106  }
107  if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
108  // Try to extract a constant or convert an existing to index.
109  values[i] = builder.getIndexAttr(*cst);
110  }
111  }
112 }
113 
114 /// Helper function to retrieve a lossless memory-space cast, and the
115 /// corresponding new result memref type.
116 static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
117 getMemorySpaceCastInfo(PtrLikeTypeInterface resultTy, Value src) {
118  MemorySpaceCastOpInterface castOp =
119  MemorySpaceCastOpInterface::getIfPromotableCast(src);
120 
121  // Bail if the cast is not lossless.
122  if (!castOp)
123  return {};
124 
125  // Transform the source and target type of `castOp` to have the same metadata
126  // as `resultTy`. Bail if not possible.
127  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
128  castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
129  if (failed(srcTy))
130  return {};
131 
132  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
133  castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
134  if (failed(tgtTy))
135  return {};
136 
137  // Check if this is a valid memory-space cast.
138  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
139  return {};
140 
141  return std::make_tuple(castOp, *tgtTy, *srcTy);
142 }
143 
144 /// Implementation of `bubbleDownCasts` method for memref operations that
145 /// return a single memref result.
146 template <typename ConcreteOpTy>
147 static FailureOr<std::optional<SmallVector<Value>>>
148 bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
149  OpOperand &src) {
150  auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
151  // Bail if we cannot cast.
152  if (!castOp)
153  return failure();
154 
155  // Create the new operands.
156  SmallVector<Value> operands;
157  llvm::append_range(operands, op->getOperands());
158  operands[src.getOperandNumber()] = castOp.getSourcePtr();
159 
160  // Create the new op and results.
161  auto newOp = ConcreteOpTy::create(
162  builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
163  llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
164 
165  // Insert a memory-space cast to the original memory space of the op.
166  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
167  builder, tgtTy,
168  cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
169  return std::optional<SmallVector<Value>>(
170  SmallVector<Value>({result.getTargetPtr()}));
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // AllocOp / AllocaOp
175 //===----------------------------------------------------------------------===//
176 
177 void AllocOp::getAsmResultNames(
178  function_ref<void(Value, StringRef)> setNameFn) {
179  setNameFn(getResult(), "alloc");
180 }
181 
182 void AllocaOp::getAsmResultNames(
183  function_ref<void(Value, StringRef)> setNameFn) {
184  setNameFn(getResult(), "alloca");
185 }
186 
187 template <typename AllocLikeOp>
188 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
189  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
190  "applies to only alloc or alloca");
191  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
192  if (!memRefType)
193  return op.emitOpError("result must be a memref");
194 
195  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
196  return op.emitOpError("dimension operand count does not equal memref "
197  "dynamic dimension count");
198 
199  unsigned numSymbols = 0;
200  if (!memRefType.getLayout().isIdentity())
201  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
202  if (op.getSymbolOperands().size() != numSymbols)
203  return op.emitOpError("symbol operand count does not equal memref symbol "
204  "count: expected ")
205  << numSymbols << ", got " << op.getSymbolOperands().size();
206 
207  return success();
208 }
209 
210 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
211 
212 LogicalResult AllocaOp::verify() {
213  // An alloca op needs to have an ancestor with an allocation scope trait.
214  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
215  return emitOpError(
216  "requires an ancestor op with AutomaticAllocationScope trait");
217 
218  return verifyAllocLikeOp(*this);
219 }
220 
221 namespace {
222 /// Fold constant dimensions into an alloc like operation.
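/// For illustration (a sketch of the rewrite; types chosen arbitrarily), an
/// alloc whose dynamic size is a known constant:
///
/// ```mlir
/// %c8 = arith.constant 8 : index
/// %0 = memref.alloc(%c8) : memref<?x16xf32>
/// ```
///
/// is rewritten to a static alloc followed by a cast back to the original
/// type:
///
/// ```mlir
/// %0 = memref.alloc() : memref<8x16xf32>
/// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x16xf32>
/// ```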
223 template <typename AllocLikeOp>
224 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
225  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
226 
227  LogicalResult matchAndRewrite(AllocLikeOp alloc,
228  PatternRewriter &rewriter) const override {
229  // Check to see if any dimension operands are constants. If so, we can
230  // substitute and drop them.
231  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
232  APInt constSizeArg;
233  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
234  return false;
235  return constSizeArg.isNonNegative();
236  }))
237  return failure();
238 
239  auto memrefType = alloc.getType();
240 
241  // Ok, we have one or more constant operands. Collect the non-constant ones
242  // and keep track of the resultant memref type to build.
243  SmallVector<int64_t, 4> newShapeConstants;
244  newShapeConstants.reserve(memrefType.getRank());
245  SmallVector<Value, 4> dynamicSizes;
246 
247  unsigned dynamicDimPos = 0;
248  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
249  int64_t dimSize = memrefType.getDimSize(dim);
250  // If this is already a static dimension, keep it.
251  if (ShapedType::isStatic(dimSize)) {
252  newShapeConstants.push_back(dimSize);
253  continue;
254  }
255  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
256  APInt constSizeArg;
257  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
258  constSizeArg.isNonNegative()) {
259  // Dynamic shape dimension will be folded.
260  newShapeConstants.push_back(constSizeArg.getZExtValue());
261  } else {
262  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
263  newShapeConstants.push_back(ShapedType::kDynamic);
264  dynamicSizes.push_back(dynamicSize);
265  }
266  dynamicDimPos++;
267  }
268 
269  // Create new memref type (which will have fewer dynamic dimensions).
270  MemRefType newMemRefType =
271  MemRefType::Builder(memrefType).setShape(newShapeConstants);
272  assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
273 
274  // Create and insert the alloc op for the new memref.
275  auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
276  dynamicSizes, alloc.getSymbolOperands(),
277  alloc.getAlignmentAttr());
278  // Insert a cast so we have the same type as the old alloc.
279  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
280  return success();
281  }
282 };
283 
284 /// Fold alloc operations with no users or only store and dealloc uses.
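/// For illustration (a sketch; names are made up), the following trio is
/// erased entirely because the allocation is only written to and then
/// deallocated:
///
/// ```mlir
/// %0 = memref.alloc() : memref<4xf32>
/// memref.store %v, %0[%i] : memref<4xf32>
/// memref.dealloc %0 : memref<4xf32>
/// ```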
285 template <typename T>
286 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
287  using OpRewritePattern<T>::OpRewritePattern;
288 
289  LogicalResult matchAndRewrite(T alloc,
290  PatternRewriter &rewriter) const override {
291  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
292  if (auto storeOp = dyn_cast<StoreOp>(op))
293  return storeOp.getValue() == alloc;
294  return !isa<DeallocOp>(op);
295  }))
296  return failure();
297 
298  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
299  rewriter.eraseOp(user);
300 
301  rewriter.eraseOp(alloc);
302  return success();
303  }
304 };
305 } // namespace
306 
307 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
308  MLIRContext *context) {
309  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
310 }
311 
312 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
313  MLIRContext *context) {
314  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
315  context);
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // ReallocOp
320 //===----------------------------------------------------------------------===//
321 
322 LogicalResult ReallocOp::verify() {
323  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
324  MemRefType resultType = getType();
325 
326  // The source memref should have identity layout (or none).
327  if (!sourceType.getLayout().isIdentity())
328  return emitError("unsupported layout for source memref type ")
329  << sourceType;
330 
331  // The result memref should have identity layout (or none).
332  if (!resultType.getLayout().isIdentity())
333  return emitError("unsupported layout for result memref type ")
334  << resultType;
335 
336  // The source memref and the result memref should be in the same memory space.
337  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
338  return emitError("different memory spaces specified for source memref "
339  "type ")
340  << sourceType << " and result memref type " << resultType;
341 
342  // The source memref and the result memref should have the same element type.
343  if (sourceType.getElementType() != resultType.getElementType())
344  return emitError("different element types specified for source memref "
345  "type ")
346  << sourceType << " and result memref type " << resultType;
347 
348  // Verify that we have the dynamic dimension operand when it is needed.
349  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
350  return emitError("missing dimension operand for result type ")
351  << resultType;
352  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
353  return emitError("unnecessary dimension operand for result type ")
354  << resultType;
355 
356  return success();
357 }
358 
359 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
360  MLIRContext *context) {
361  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // AllocaScopeOp
366 //===----------------------------------------------------------------------===//
367 
368 void AllocaScopeOp::print(OpAsmPrinter &p) {
369  bool printBlockTerminators = false;
370 
371  p << ' ';
372  if (!getResults().empty()) {
373  p << " -> (" << getResultTypes() << ")";
374  printBlockTerminators = true;
375  }
376  p << ' ';
377  p.printRegion(getBodyRegion(),
378  /*printEntryBlockArgs=*/false,
379  /*printBlockTerminators=*/printBlockTerminators);
380  p.printOptionalAttrDict((*this)->getAttrs());
381 }
382 
383 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
384  // Create a region for the body.
385  result.regions.reserve(1);
386  Region *bodyRegion = result.addRegion();
387 
388  // Parse optional results type list.
389  if (parser.parseOptionalArrowTypeList(result.types))
390  return failure();
391 
392  // Parse the body region.
393  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
394  return failure();
395  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
396  result.location);
397 
398  // Parse the optional attribute list.
399  if (parser.parseOptionalAttrDict(result.attributes))
400  return failure();
401 
402  return success();
403 }
404 
405 void AllocaScopeOp::getSuccessorRegions(
406  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
407  if (!point.isParent()) {
408  regions.push_back(RegionSuccessor(getResults()));
409  return;
410  }
411 
412  regions.push_back(RegionSuccessor(&getBodyRegion()));
413 }
414 
415 /// Given an operation, return whether this op is guaranteed to
416 /// allocate an AutomaticAllocationScopeResource
417 static bool isGuaranteedAutomaticAllocation(Operation *op) {
418  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
419  if (!interface)
420  return false;
421  for (auto res : op->getResults()) {
422  if (auto effect =
423  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
424  if (isa<SideEffects::AutomaticAllocationScopeResource>(
425  effect->getResource()))
426  return true;
427  }
428  }
429  return false;
430 }
431 
432 /// Given an operation, return whether this op itself could
433 /// allocate an AutomaticAllocationScopeResource. Note that
434 /// this will not check whether an operation contained within
435 /// the op can allocate.
436 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
437  // This op itself doesn't create a stack allocation,
438  // the inner allocation should be handled separately.
439  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
440  return false;
441  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
442  if (!interface)
443  return true;
444  for (auto res : op->getResults()) {
445  if (auto effect =
446  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
447  if (isa<SideEffects::AutomaticAllocationScopeResource>(
448  effect->getResource()))
449  return true;
450  }
451  }
452  return false;
453 }
454 
455 /// Return whether this op is the last non-terminating op
456 /// in a region. That is to say, it is in a one-block region
457 /// and is only followed by a terminator. This prevents
458 /// extending the lifetime of allocations.
459 static bool lastNonTerminatorInRegion(Operation *op) {
460  return op->getBlock()->mightHaveTerminator() &&
461  op->getNextNode() == op->getBlock()->getTerminator() &&
462  op->getParentRegion()->hasOneBlock();
463 }
464 
465 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
466 /// or it contains no allocation.
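/// For illustration (a sketch assuming a scope that contains no allocation;
/// names are made up), the pattern rewrites:
///
/// ```mlir
/// %r = memref.alloca_scope -> (f32) {
///   %sum = arith.addf %a, %b : f32
///   memref.alloca_scope.return %sum : f32
/// }
/// ```
///
/// into:
///
/// ```mlir
/// %r = arith.addf %a, %b : f32
/// ```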
467 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
468  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
469 
470  LogicalResult matchAndRewrite(AllocaScopeOp op,
471  PatternRewriter &rewriter) const override {
472  bool hasPotentialAlloca =
473  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
474  if (alloc == op)
475  return WalkResult::advance();
476  if (isOpItselfPotentialAutomaticAllocation(alloc))
477  return WalkResult::interrupt();
478  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
479  return WalkResult::skip();
480  return WalkResult::advance();
481  }).wasInterrupted();
482 
483  // If this contains no potential allocation, it is always legal to
484  // inline. Otherwise, consider two conditions:
485  if (hasPotentialAlloca) {
486  // If the parent isn't an allocation scope, or we are not the last
487  // non-terminator op in the parent, we will extend the lifetime.
488  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
489  return failure();
490  if (!lastNonTerminatorInRegion(op))
491  return failure();
492  }
493 
494  Block *block = &op.getRegion().front();
495  Operation *terminator = block->getTerminator();
496  ValueRange results = terminator->getOperands();
497  rewriter.inlineBlockBefore(block, op);
498  rewriter.replaceOp(op, results);
499  rewriter.eraseOp(terminator);
500  return success();
501  }
502 };
503 
504 /// Move allocations into an allocation scope, if it is legal to
505 /// move them (e.g. their operands are available at the location
506 /// the op would be moved to).
507 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
508  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
509 
510  LogicalResult matchAndRewrite(AllocaScopeOp op,
511  PatternRewriter &rewriter) const override {
512 
513  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
514  return failure();
515 
516  Operation *lastParentWithoutScope = op->getParentOp();
517 
518  if (!lastParentWithoutScope ||
519  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
520  return failure();
521 
522  // Only apply if this is the last non-terminator
523  // op in the block (lest lifetimes be extended) of a
524  // one-block region.
525  if (!lastNonTerminatorInRegion(op) ||
526  !lastNonTerminatorInRegion(lastParentWithoutScope))
527  return failure();
528 
529  while (!lastParentWithoutScope->getParentOp()
530  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
531  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
532  if (!lastParentWithoutScope ||
533  !lastNonTerminatorInRegion(lastParentWithoutScope))
534  return failure();
535  }
536  assert(lastParentWithoutScope->getParentOp()
537  ->hasTrait<OpTrait::AutomaticAllocationScope>());
538 
539  Region *containingRegion = nullptr;
540  for (auto &r : lastParentWithoutScope->getRegions()) {
541  if (r.isAncestor(op->getParentRegion())) {
542  assert(containingRegion == nullptr &&
543  "only one region can contain the op");
544  containingRegion = &r;
545  }
546  }
547  assert(containingRegion && "op must be contained in a region");
548 
549  SmallVector<Operation *> toHoist;
550  op->walk([&](Operation *alloc) {
551  if (!isGuaranteedAutomaticAllocation(alloc))
552  return WalkResult::skip();
553 
554  // If any operand is not defined before the location of
555  // lastParentWithoutScope (i.e. where we would hoist to), skip.
556  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
557  return containingRegion->isAncestor(v.getParentRegion());
558  }))
559  return WalkResult::skip();
560  toHoist.push_back(alloc);
561  return WalkResult::advance();
562  });
563 
564  if (toHoist.empty())
565  return failure();
566  rewriter.setInsertionPoint(lastParentWithoutScope);
567  for (auto *op : toHoist) {
568  auto *cloned = rewriter.clone(*op);
569  rewriter.replaceOp(op, cloned->getResults());
570  }
571  return success();
572  }
573 };
574 
575 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
576  MLIRContext *context) {
577  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
578 }
579 
580 //===----------------------------------------------------------------------===//
581 // AssumeAlignmentOp
582 //===----------------------------------------------------------------------===//
583 
584 LogicalResult AssumeAlignmentOp::verify() {
585  if (!llvm::isPowerOf2_32(getAlignment()))
586  return emitOpError("alignment must be power of 2");
587  return success();
588 }
589 
590 void AssumeAlignmentOp::getAsmResultNames(
591  function_ref<void(Value, StringRef)> setNameFn) {
592  setNameFn(getResult(), "assume_align");
593 }
594 
595 OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
596  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
597  if (!source)
598  return {};
599  if (source.getAlignment() != getAlignment())
600  return {};
601  return getMemref();
602 }
603 
604 FailureOr<std::optional<SmallVector<Value>>>
605 AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
606  return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
607 }
608 
609 //===----------------------------------------------------------------------===//
610 // DistinctObjectsOp
611 //===----------------------------------------------------------------------===//
612 
613 LogicalResult DistinctObjectsOp::verify() {
614  if (getOperandTypes() != getResultTypes())
615  return emitOpError("operand types and result types must match");
616 
617  if (getOperandTypes().empty())
618  return emitOpError("expected at least one operand");
619 
620  return success();
621 }
622 
623 LogicalResult DistinctObjectsOp::inferReturnTypes(
624  MLIRContext * /*context*/, std::optional<Location> /*location*/,
625  ValueRange operands, DictionaryAttr /*attributes*/,
626  OpaqueProperties /*properties*/, RegionRange /*regions*/,
627  SmallVectorImpl<Type> &inferredReturnTypes) {
628  llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
629  return success();
630 }
631 
632 //===----------------------------------------------------------------------===//
633 // CastOp
634 //===----------------------------------------------------------------------===//
635 
636 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
637  setNameFn(getResult(), "cast");
638 }
639 
640 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
641 /// source memref. This is useful to fold a memref.cast into a consuming op
642 /// and implement canonicalization patterns for ops in different dialects that
643 /// may consume the results of memref.cast operations. Such foldable memref.cast
644 /// operations are typically inserted as `view` and `subview` ops are
645 /// canonicalized, to preserve the type compatibility of their uses.
646 ///
647 /// Returns true when all conditions are met:
648 /// 1. source and result are ranked memrefs with strided semantics and same
649 /// element type and rank.
650 /// 2. each of the source's size, offset or stride has more static information
651 /// than the corresponding result's size, offset or stride.
652 ///
653 /// Example 1:
654 /// ```mlir
655 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
656 /// %2 = consumer %1 ... : memref<?x?xf32> ...
657 /// ```
658 ///
659 /// may fold into:
660 ///
661 /// ```mlir
662 /// %2 = consumer %0 ... : memref<8x16xf32> ...
663 /// ```
664 ///
665 /// Example 2:
666 /// ```
667 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
668 /// to memref<?x?xf32>
669 /// consumer %1 : memref<?x?xf32> ...
670 /// ```
671 ///
672 /// may fold into:
673 ///
674 /// ```
675 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
676 /// ```
677 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
678  MemRefType sourceType =
679  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
680  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
681 
682  // Requires ranked MemRefType.
683  if (!sourceType || !resultType)
684  return false;
685 
686  // Requires same elemental type.
687  if (sourceType.getElementType() != resultType.getElementType())
688  return false;
689 
690  // Requires same rank.
691  if (sourceType.getRank() != resultType.getRank())
692  return false;
693 
694  // Only fold casts between strided memref forms.
695  int64_t sourceOffset, resultOffset;
696  SmallVector<int64_t, 4> sourceStrides, resultStrides;
697  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
698  failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
699  return false;
700 
701  // If cast is towards more static sizes along any dimension, don't fold.
702  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
703  auto ss = std::get<0>(it), st = std::get<1>(it);
704  if (ss != st)
705  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
706  return false;
707  }
708 
709  // If cast is towards more static offset along any dimension, don't fold.
710  if (sourceOffset != resultOffset)
711  if (ShapedType::isDynamic(sourceOffset) &&
712  ShapedType::isStatic(resultOffset))
713  return false;
714 
715  // If cast is towards more static strides along any dimension, don't fold.
716  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
717  auto ss = std::get<0>(it), st = std::get<1>(it);
718  if (ss != st)
719  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
720  return false;
721  }
722 
723  return true;
724 }
725 
726 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
727  if (inputs.size() != 1 || outputs.size() != 1)
728  return false;
729  Type a = inputs.front(), b = outputs.front();
730  auto aT = llvm::dyn_cast<MemRefType>(a);
731  auto bT = llvm::dyn_cast<MemRefType>(b);
732 
733  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
734  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
735 
736  if (aT && bT) {
737  if (aT.getElementType() != bT.getElementType())
738  return false;
739  if (aT.getLayout() != bT.getLayout()) {
740  int64_t aOffset, bOffset;
741  SmallVector<int64_t, 4> aStrides, bStrides;
742  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
743  failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
744  aStrides.size() != bStrides.size())
745  return false;
746 
747  // Strides along a dimension/offset are compatible if the value in the
748  // source memref is static and the value in the target memref is the
749  // same. They are also compatible if either one is dynamic (see
750  // description of MemRefCastOp for details).
751  auto checkCompatible = [](int64_t a, int64_t b) {
752  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
753  };
754  if (!checkCompatible(aOffset, bOffset))
755  return false;
756  for (const auto &aStride : enumerate(aStrides))
757  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
758  return false;
759  }
760  if (aT.getMemorySpace() != bT.getMemorySpace())
761  return false;
762 
763  // They must have the same rank, and any specified dimensions must match.
764  if (aT.getRank() != bT.getRank())
765  return false;
766 
767  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
768  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
769  if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
770  aDim != bDim)
771  return false;
772  }
773  return true;
774  } else {
775  if (!aT && !uaT)
776  return false;
777  if (!bT && !ubT)
778  return false;
779  // Unranked to unranked casting is unsupported
780  if (uaT && ubT)
781  return false;
782 
783  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
784  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
785  if (aEltType != bEltType)
786  return false;
787 
788  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
789  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
790  return aMemSpace == bMemSpace;
791  }
792 
793  return false;
794 }
795 
796 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
797  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
798 }
799 
800 FailureOr<std::optional<SmallVector<Value>>>
801 CastOp::bubbleDownCasts(OpBuilder &builder) {
802  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
803 }
804 
805 //===----------------------------------------------------------------------===//
806 // CopyOp
807 //===----------------------------------------------------------------------===//
808 
809 namespace {
810 
811 /// Fold memref.copy(%x, %x).
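/// For illustration, a self-copy such as:
///
/// ```mlir
/// memref.copy %m, %m : memref<4xf32> to memref<4xf32>
/// ```
///
/// is simply erased.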
812 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
813  using OpRewritePattern<CopyOp>::OpRewritePattern;
814 
815  LogicalResult matchAndRewrite(CopyOp copyOp,
816  PatternRewriter &rewriter) const override {
817  if (copyOp.getSource() != copyOp.getTarget())
818  return failure();
819 
820  rewriter.eraseOp(copyOp);
821  return success();
822  }
823 };
824 
825 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
826  using OpRewritePattern<CopyOp>::OpRewritePattern;
827 
828  static bool isEmptyMemRef(BaseMemRefType type) {
829  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
830  }
831 
832  LogicalResult matchAndRewrite(CopyOp copyOp,
833  PatternRewriter &rewriter) const override {
834  if (isEmptyMemRef(copyOp.getSource().getType()) ||
835  isEmptyMemRef(copyOp.getTarget().getType())) {
836  rewriter.eraseOp(copyOp);
837  return success();
838  }
839 
840  return failure();
841  }
842 };
843 } // namespace
844 
845 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
846  MLIRContext *context) {
847  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
848 }
849 
850 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
851 /// and element type, the cast can be skipped. Such CastOps only cast the layout
852 /// of the type.
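/// For illustration (a sketch; names are made up, and the cast here only makes
/// the layout offset more dynamic, so it is foldable per
/// `CastOp::canFoldIntoConsumerOp`):
///
/// ```mlir
/// %0 = memref.cast %src : memref<4xf32> to memref<4xf32, strided<[1], offset: ?>>
/// memref.copy %0, %dst : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32>
/// ```
///
/// becomes:
///
/// ```mlir
/// memref.copy %src, %dst : memref<4xf32> to memref<4xf32>
/// ```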
853 static LogicalResult FoldCopyOfCast(CopyOp op) {
854  for (OpOperand &operand : op->getOpOperands()) {
855  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
856  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
857  operand.set(castOp.getOperand());
858  return success();
859  }
860  }
861  return failure();
862 }
863 
864 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
865  SmallVectorImpl<OpFoldResult> &results) {
866 
867  /// copy(memrefcast) -> copy
868  return FoldCopyOfCast(*this);
869 }
870 
871 //===----------------------------------------------------------------------===//
872 // DeallocOp
873 //===----------------------------------------------------------------------===//
874 
875 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
876  SmallVectorImpl<OpFoldResult> &results) {
877  /// dealloc(memrefcast) -> dealloc
878  return foldMemRefCast(*this);
879 }
880 
881 //===----------------------------------------------------------------------===//
882 // DimOp
883 //===----------------------------------------------------------------------===//
884 
885 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
886  setNameFn(getResult(), "dim");
887 }
888 
889 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
890  int64_t index) {
891  auto loc = result.location;
892  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
893  build(builder, result, source, indexValue);
894 }
895 
896 std::optional<int64_t> DimOp::getConstantIndex() {
897  return getConstantIntValue(getIndex());
898 }
899 
900 Speculation::Speculatability DimOp::getSpeculatability() {
901  auto constantIndex = getConstantIndex();
902  if (!constantIndex)
903  return Speculation::NotSpeculatable;
904 
905  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
906  if (!rankedSourceType)
907  return Speculation::NotSpeculatable;
908 
909  if (rankedSourceType.getRank() <= constantIndex)
910  return Speculation::NotSpeculatable;
911 
912  return Speculation::Speculatable;
913 }
914 
915 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
916  SetIntLatticeFn setResultRange) {
917  setResultRange(getResult(),
918  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
919 }
920 
921 /// Return a map whose keys are the elements of `vals` and whose values are the
922 /// number of occurrences of each element. Use std::map, since the `vals` here
923 /// are strides and the dynamic stride value is the same as the tombstone value
924 /// for `DenseMap<int64_t>`.
925 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
926  std::map<int64_t, unsigned> numOccurences;
927  for (auto val : vals)
928  numOccurences[val]++;
929  return numOccurences;
930 }
931 
932 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
933 /// to be a subset of `originalType` with some `1` entries erased, return the
934 /// set of indices that specifies which of the entries of `originalShape` are
935 /// dropped to obtain `reducedShape`.
936 /// This accounts for cases where there are multiple unit-dims, but only a
937 /// subset of those are dropped. For MemRefTypes these can be disambiguated
938 /// using the strides. If a dimension is dropped the stride must be dropped too.
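/// For illustration (a sketch under the assumption of an identity-layout
/// source; names are made up), consider the rank-reducing subview:
///
/// ```mlir
/// %0 = memref.subview %arg[0, 0, 0] [1, 4, 1] [1, 1, 1]
///     : memref<1x4x1xf32> to memref<4x1xf32, strided<[1, 1]>>
/// ```
///
/// The source strides are [4, 1, 1] and the reduced strides are [1, 1]. Both
/// unit dims are candidates, but only dim 0 is reported as dropped: its
/// stride (4) no longer occurs in the reduced type, whereas dim 2's stride (1)
/// is still present.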
939 static FailureOr<llvm::SmallBitVector>
940 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
941  ArrayRef<OpFoldResult> sizes) {
942  llvm::SmallBitVector unusedDims(originalType.getRank());
943  if (originalType.getRank() == reducedType.getRank())
944  return unusedDims;
945 
946  for (const auto &dim : llvm::enumerate(sizes))
947  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
948  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
949  unusedDims.set(dim.index());
950 
951  // Early exit for the case where the number of unused dims matches the number
952  // of ranks reduced.
953  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
954  originalType.getRank())
955  return unusedDims;
956 
957  SmallVector<int64_t> originalStrides, candidateStrides;
958  int64_t originalOffset, candidateOffset;
959  if (failed(
960  originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
961  failed(
962  reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
963  return failure();
964 
965  // For memrefs, a dimension is truly dropped if its corresponding stride is
966  // also dropped. This is particularly important when more than one of the
967  // dims is 1. Track the number of occurrences of the strides in the original
968  // type and the candidate type. For each unused dim that stride should not be
969  // present in the candidate type. Note that there could be multiple dimensions
970  // that have the same size. We don't need to figure out exactly which dim
971  // corresponds to which stride, we just need to verify that the number of
972  // repetitions of a stride in the original + number of unused dims with that
973  // stride == number of repetitions of that stride in the candidate.
974  std::map<int64_t, unsigned> currUnaccountedStrides =
975  getNumOccurences(originalStrides);
976  std::map<int64_t, unsigned> candidateStridesNumOccurences =
977  getNumOccurences(candidateStrides);
978  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
979  if (!unusedDims.test(dim))
980  continue;
981  int64_t originalStride = originalStrides[dim];
982  if (currUnaccountedStrides[originalStride] >
983  candidateStridesNumOccurences[originalStride]) {
984  // This dim can be treated as dropped.
985  currUnaccountedStrides[originalStride]--;
986  continue;
987  }
988  if (currUnaccountedStrides[originalStride] ==
989  candidateStridesNumOccurences[originalStride]) {
990  // The stride for this is not dropped. Keep as is.
991  unusedDims.reset(dim);
992  continue;
993  }
994  if (currUnaccountedStrides[originalStride] <
995  candidateStridesNumOccurences[originalStride]) {
996  // This should never happen. Can't have a stride in the reduced-rank type
997  // that wasn't in the original one.
998  return failure();
999  }
1000  }
1001 
1002  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1003  originalType.getRank())
1004  return failure();
1005  return unusedDims;
1006 }
1007 
1008 llvm::SmallBitVector SubViewOp::getDroppedDims() {
1009  MemRefType sourceType = getSourceType();
1010  MemRefType resultType = getType();
1011  FailureOr<llvm::SmallBitVector> unusedDims =
1012  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1013  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1014  return *unusedDims;
1015 }
1016 
1017 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1018  // All forms of folding require a known index.
1019  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1020  if (!index)
1021  return {};
1022 
1023  // Folding for unranked types (UnrankedMemRefType) is not supported.
1024  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1025  if (!memrefType)
1026  return {};
1027 
1028  // Out of bound indices produce undefined behavior but are still valid IR.
1029  // Don't choke on them.
1030  int64_t indexVal = index.getInt();
1031  if (indexVal < 0 || indexVal >= memrefType.getRank())
1032  return {};
1033 
1034  // Fold if the shape extent along the given index is known.
1035  if (!memrefType.isDynamicDim(index.getInt())) {
1036  Builder builder(getContext());
1037  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1038  }
1039 
1040  // The size at the given index is now known to be a dynamic size.
1041  unsigned unsignedIndex = index.getValue().getZExtValue();
1042 
1043  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1044  Operation *definingOp = getSource().getDefiningOp();
1045 
1046  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1047  return *(alloc.getDynamicSizes().begin() +
1048  memrefType.getDynamicDimIndex(unsignedIndex));
1049 
1050  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1051  return *(alloca.getDynamicSizes().begin() +
1052  memrefType.getDynamicDimIndex(unsignedIndex));
1053 
1054  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1055  return *(view.getDynamicSizes().begin() +
1056  memrefType.getDynamicDimIndex(unsignedIndex));
1057 
1058  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1059  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1060  unsigned resultIndex = 0;
1061  unsigned sourceRank = subview.getSourceType().getRank();
1062  unsigned sourceIndex = 0;
1063  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1064  if (unusedDims.test(i))
1065  continue;
1066  if (resultIndex == unsignedIndex) {
1067  sourceIndex = i;
1068  break;
1069  }
1070  resultIndex++;
1071  }
1072  assert(subview.isDynamicSize(sourceIndex) &&
1073  "expected dynamic subview size");
1074  return subview.getDynamicSize(sourceIndex);
1075  }
1076 
1077  if (auto sizeInterface =
1078  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1079  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1080  "Expected dynamic subview size");
1081  return sizeInterface.getDynamicSize(unsignedIndex);
1082  }
1083 
1084  // dim(memrefcast) -> dim
1085  if (succeeded(foldMemRefCast(*this)))
1086  return getResult();
1087 
1088  return {};
1089 }
1090 
1091 namespace {
1092 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1093 /// operand.
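/// For illustration (a sketch; names are made up), a `dim` of the reshape
/// result is answered by loading the requested entry of the shape operand:
///
/// ```mlir
/// %r = memref.reshape %src(%shape)
///     : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
/// %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// folds to:
///
/// ```mlir
/// %d = memref.load %shape[%c1] : memref<2xindex>
/// ```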
1094 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1095  using OpRewritePattern<DimOp>::OpRewritePattern;
1096 
1097  LogicalResult matchAndRewrite(DimOp dim,
1098  PatternRewriter &rewriter) const override {
1099  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1100 
1101  if (!reshape)
1102  return rewriter.notifyMatchFailure(
1103  dim, "Dim op is not defined by a reshape op.");
1104 
1105  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1106  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1107  // cheaply check that either of the following conditions hold:
1108  // 1. dim.getIndex() is defined in the same block as reshape but before
1109  // reshape.
1110  // 2. dim.getIndex() is defined in a parent block of
1111  // reshape.
1112 
1113  // Check condition 1
1114  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1115  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1116  if (reshape->isBeforeInBlock(definingOp)) {
1117  return rewriter.notifyMatchFailure(
1118  dim,
1119  "dim.getIndex is not defined before reshape in the same block.");
1120  }
1121  } // else dim.getIndex is a block argument to reshape->getBlock and
1122  // dominates reshape
1123  } // Check condition 2
1124  else if (dim->getBlock() != reshape->getBlock() &&
1125  !dim.getIndex().getParentRegion()->isProperAncestor(
1126  reshape->getParentRegion())) {
1127  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1128  // already know dim.getIndex() dominates reshape without calling
1129  // `isProperAncestor`
1130  return rewriter.notifyMatchFailure(
1131  dim, "dim.getIndex does not dominate reshape.");
1132  }
1133 
1134  // Place the load directly after the reshape to ensure that the shape memref
1135  // was not mutated.
1136  rewriter.setInsertionPointAfter(reshape);
1137  Location loc = dim.getLoc();
1138  Value load =
1139  LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1140  if (load.getType() != dim.getType())
1141  load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
1142  rewriter.replaceOp(dim, load);
1143  return success();
1144  }
1145 };
1146 
1147 } // namespace
1148 
1149 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1150  MLIRContext *context) {
1151  results.add<DimOfMemRefReshape>(context);
1152 }
1153 
1154 // ---------------------------------------------------------------------------
1155 // DmaStartOp
1156 // ---------------------------------------------------------------------------
1157 
1158 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1159  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1160  ValueRange destIndices, Value numElements,
1161  Value tagMemRef, ValueRange tagIndices, Value stride,
1162  Value elementsPerStride) {
1163  result.addOperands(srcMemRef);
1164  result.addOperands(srcIndices);
1165  result.addOperands(destMemRef);
1166  result.addOperands(destIndices);
1167  result.addOperands({numElements, tagMemRef});
1168  result.addOperands(tagIndices);
1169  if (stride)
1170  result.addOperands({stride, elementsPerStride});
1171 }
1172 
1173 void DmaStartOp::print(OpAsmPrinter &p) {
1174  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1175  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1176  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1177  if (isStrided())
1178  p << ", " << getStride() << ", " << getNumElementsPerStride();
1179 
1180  p.printOptionalAttrDict((*this)->getAttrs());
1181  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1182  << ", " << getTagMemRef().getType();
1183 }
1184 
1185 // Parse DmaStartOp.
1186 // Ex:
1187 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1188 // %tag[%index], %stride, %num_elt_per_stride :
1189 // : memref<3076 x f32, 0>,
1190 // memref<1024 x f32, 2>,
1191 // memref<1 x i32>
1192 //
1193 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1194  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1195  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1196  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1197  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1198  OpAsmParser::UnresolvedOperand numElementsInfo;
1199  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1200  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1201  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1202 
1203  SmallVector<Type, 3> types;
1204  auto indexType = parser.getBuilder().getIndexType();
1205 
1206  // Parse and resolve the following list of operands:
1207  // *) source memref followed by its indices (in square brackets).
1208  // *) destination memref followed by its indices (in square brackets).
1209  // *) number of elements to transfer.
1210  if (parser.parseOperand(srcMemRefInfo) ||
1211  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1212  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1213  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1214  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1215  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1216  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1217  return failure();
1218 
1219  // Parse optional stride and elements per stride.
1220  if (parser.parseTrailingOperandList(strideInfo))
1221  return failure();
1222 
1223  bool isStrided = strideInfo.size() == 2;
1224  if (!strideInfo.empty() && !isStrided) {
1225  return parser.emitError(parser.getNameLoc(),
1226  "expected two stride related operands");
1227  }
1228 
1229  if (parser.parseColonTypeList(types))
1230  return failure();
1231  if (types.size() != 3)
1232  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1233 
1234  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1235  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1236  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1237  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1238  // size should be an index.
1239  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1240  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1241  // tag indices should be index.
1242  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1243  return failure();
1244 
1245  if (isStrided) {
1246  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1247  return failure();
1248  }
1249 
1250  return success();
1251 }
1252 
1253 LogicalResult DmaStartOp::verify() {
1254  unsigned numOperands = getNumOperands();
1255 
1256  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1257  // the number of elements.
1258  if (numOperands < 4)
1259  return emitOpError("expected at least 4 operands");
1260 
1261  // Check types of operands. The order of these calls is important: the later
1262  // calls rely on some type properties to compute the operand position.
1263  // 1. Source memref.
1264  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1265  return emitOpError("expected source to be of memref type");
1266  if (numOperands < getSrcMemRefRank() + 4)
1267  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1268  << " operands";
1269  if (!getSrcIndices().empty() &&
1270  !llvm::all_of(getSrcIndices().getTypes(),
1271  [](Type t) { return t.isIndex(); }))
1272  return emitOpError("expected source indices to be of index type");
1273 
1274  // 2. Destination memref.
1275  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1276  return emitOpError("expected destination to be of memref type");
1277  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1278  if (numOperands < numExpectedOperands)
1279  return emitOpError() << "expected at least " << numExpectedOperands
1280  << " operands";
1281  if (!getDstIndices().empty() &&
1282  !llvm::all_of(getDstIndices().getTypes(),
1283  [](Type t) { return t.isIndex(); }))
1284  return emitOpError("expected destination indices to be of index type");
1285 
1286  // 3. Number of elements.
1287  if (!getNumElements().getType().isIndex())
1288  return emitOpError("expected num elements to be of index type");
1289 
1290  // 4. Tag memref.
1291  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1292  return emitOpError("expected tag to be of memref type");
1293  numExpectedOperands += getTagMemRefRank();
1294  if (numOperands < numExpectedOperands)
1295  return emitOpError() << "expected at least " << numExpectedOperands
1296  << " operands";
1297  if (!getTagIndices().empty() &&
1298  !llvm::all_of(getTagIndices().getTypes(),
1299  [](Type t) { return t.isIndex(); }))
1300  return emitOpError("expected tag indices to be of index type");
1301 
1302  // Optional stride-related operands must be either both present or both
1303  // absent.
1304  if (numOperands != numExpectedOperands &&
1305  numOperands != numExpectedOperands + 2)
1306  return emitOpError("incorrect number of operands");
1307 
1308  // 5. Strides.
1309  if (isStrided()) {
1310  if (!getStride().getType().isIndex() ||
1311  !getNumElementsPerStride().getType().isIndex())
1312  return emitOpError(
1313  "expected stride and num elements per stride to be of type index");
1314  }
1315 
1316  return success();
1317 }
1318 
1319 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1320  SmallVectorImpl<OpFoldResult> &results) {
1321  /// dma_start(memrefcast) -> dma_start
1322  return foldMemRefCast(*this);
1323 }
1324 
1325 // ---------------------------------------------------------------------------
1326 // DmaWaitOp
1327 // ---------------------------------------------------------------------------
1328 
1329 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1330  SmallVectorImpl<OpFoldResult> &results) {
1331  /// dma_wait(memrefcast) -> dma_wait
1332  return foldMemRefCast(*this);
1333 }
1334 
1335 LogicalResult DmaWaitOp::verify() {
1336  // Check that the number of tag indices matches the tagMemRef rank.
1337  unsigned numTagIndices = getTagIndices().size();
1338  unsigned tagMemRefRank = getTagMemRefRank();
1339  if (numTagIndices != tagMemRefRank)
1340  return emitOpError() << "expected tagIndices to have the same number of "
1341  "elements as the tagMemRef rank, expected "
1342  << tagMemRefRank << ", but got " << numTagIndices;
1343  return success();
1344 }
1345 
1346 //===----------------------------------------------------------------------===//
1347 // ExtractAlignedPointerAsIndexOp
1348 //===----------------------------------------------------------------------===//
1349 
1350 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1351  function_ref<void(Value, StringRef)> setNameFn) {
1352  setNameFn(getResult(), "intptr");
1353 }
1354 
1355 //===----------------------------------------------------------------------===//
1356 // ExtractStridedMetadataOp
1357 //===----------------------------------------------------------------------===//
1358 
1359 /// The number and type of the results are inferred from the
1360 /// shape of the source.
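/// For illustration (a sketch spelled from memory of the op's assembly; names
/// are made up), a rank-2 source yields one base buffer, one offset, two
/// sizes, and two strides:
///
/// ```mlir
/// %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
///     : memref<4x?xf32> -> memref<f32>, index, index, index, index, index
/// ```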
1361 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1362  MLIRContext *context, std::optional<Location> location,
1363  ExtractStridedMetadataOp::Adaptor adaptor,
1364  SmallVectorImpl<Type> &inferredReturnTypes) {
1365  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1366  if (!sourceType)
1367  return failure();
1368 
1369  unsigned sourceRank = sourceType.getRank();
1370  IndexType indexType = IndexType::get(context);
1371  auto memrefType =
1372  MemRefType::get({}, sourceType.getElementType(),
1373  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1374  // Base.
1375  inferredReturnTypes.push_back(memrefType);
1376  // Offset.
1377  inferredReturnTypes.push_back(indexType);
1378  // Sizes and strides.
1379  for (unsigned i = 0; i < sourceRank * 2; ++i)
1380  inferredReturnTypes.push_back(indexType);
1381  return success();
1382 }
1383 
1384 void ExtractStridedMetadataOp::getAsmResultNames(
1385  function_ref<void(Value, StringRef)> setNameFn) {
1386  setNameFn(getBaseBuffer(), "base_buffer");
1387  setNameFn(getOffset(), "offset");
1388  // For multi-result to work properly with pretty names and packed syntax `x:3`
1389  // we can only give a pretty name to the first value in the pack.
1390  if (!getSizes().empty()) {
1391  setNameFn(getSizes().front(), "sizes");
1392  setNameFn(getStrides().front(), "strides");
1393  }
1394 }
1395 
1396 /// Helper function to perform the replacement of all constant uses of `values`
1397 /// by a materialized constant extracted from `maybeConstants`.
1398 /// `values` and `maybeConstants` are expected to have the same size.
1399 template <typename Container>
1400 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1401  Container values,
1402  ArrayRef<OpFoldResult> maybeConstants) {
1403  assert(values.size() == maybeConstants.size() &&
1404  " expected values and maybeConstants of the same size");
1405  bool atLeastOneReplacement = false;
1406  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1407  // Don't materialize a constant if there are no uses: this would induce
1408  // infinite loops in the driver.
1409  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1410  continue;
1411  assert(isa<Attribute>(maybeConstant) &&
1412  "The constified value should be either unchanged (i.e., == result) "
1413  "or a constant");
1414  Value constantVal = arith::ConstantIndexOp::create(
1415  rewriter, loc,
1416  llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1417  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1418  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1419  // yet.
1420  op->replaceUsesOfWith(result, constantVal);
1421  atLeastOneReplacement = true;
1422  }
1423  }
1424  return atLeastOneReplacement;
1425 }
1426 
1427 LogicalResult
1428 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1429  SmallVectorImpl<OpFoldResult> &results) {
1430  OpBuilder builder(*this);
1431 
1432  bool atLeastOneReplacement = replaceConstantUsesOf(
1433  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1434  getConstifiedMixedOffset());
1435  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1436  getConstifiedMixedSizes());
1437  atLeastOneReplacement |= replaceConstantUsesOf(
1438  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1439 
1440  return success(atLeastOneReplacement);
1441 }
1442 
1443 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1444  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1445  constifyIndexValues(values, getSource().getType().getShape());
1446  return values;
1447 }
1448 
1449 SmallVector<OpFoldResult>
1450 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1451  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1452  SmallVector<int64_t> staticValues;
1453  int64_t unused;
1454  LogicalResult status =
1455  getSource().getType().getStridesAndOffset(staticValues, unused);
1456  (void)status;
1457  assert(succeeded(status) && "could not get strides from type");
1458  constifyIndexValues(values, staticValues);
1459  return values;
1460 }
1461 
1462 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1463  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1464  SmallVector<OpFoldResult> values(1, offsetOfr);
1465  SmallVector<int64_t> staticValues, unused;
1466  int64_t offset;
1467  LogicalResult status =
1468  getSource().getType().getStridesAndOffset(unused, offset);
1469  (void)status;
1470  assert(succeeded(status) && "could not get offset from type");
1471  staticValues.push_back(offset);
1472  constifyIndexValues(values, staticValues);
1473  return values[0];
1474 }
1475 
1476 //===----------------------------------------------------------------------===//
1477 // GenericAtomicRMWOp
1478 //===----------------------------------------------------------------------===//
1479 
1480 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1481  Value memref, ValueRange ivs) {
1482  OpBuilder::InsertionGuard g(builder);
1483  result.addOperands(memref);
1484  result.addOperands(ivs);
1485 
1486  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1487  Type elementType = memrefType.getElementType();
1488  result.addTypes(elementType);
1489 
1490  Region *bodyRegion = result.addRegion();
1491  builder.createBlock(bodyRegion);
1492  bodyRegion->addArgument(elementType, memref.getLoc());
1493  }
1494 }
1495 
1496 LogicalResult GenericAtomicRMWOp::verify() {
1497  auto &body = getRegion();
1498  if (body.getNumArguments() != 1)
1499  return emitOpError("expected single number of entry block arguments");
1500 
1501  if (getResult().getType() != body.getArgument(0).getType())
1502  return emitOpError("expected block argument type to match result type");
1503 
1504  bool hasSideEffects =
1505  body.walk([&](Operation *nestedOp) {
1506  if (isMemoryEffectFree(nestedOp))
1507  return WalkResult::advance();
1508  nestedOp->emitError(
1509  "body of 'memref.generic_atomic_rmw' should contain "
1510  "only operations with no side effects");
1511  return WalkResult::interrupt();
1512  })
1513  .wasInterrupted();
1514  return hasSideEffects ? failure() : success();
1515 }
1516 
1517 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1518  OperationState &result) {
1519  OpAsmParser::UnresolvedOperand memref;
1520  Type memrefType;
1521  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1522 
1523  Type indexType = parser.getBuilder().getIndexType();
1524  if (parser.parseOperand(memref) ||
1525  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1526  parser.parseColonType(memrefType) ||
1527  parser.resolveOperand(memref, memrefType, result.operands) ||
1528  parser.resolveOperands(ivs, indexType, result.operands))
1529  return failure();
1530 
1531  Region *body = result.addRegion();
1532  if (parser.parseRegion(*body, {}) ||
1533  parser.parseOptionalAttrDict(result.attributes))
1534  return failure();
1535  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1536  return success();
1537 }
1538 
1539 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1540  p << ' ' << getMemref() << "[" << getIndices()
1541  << "] : " << getMemref().getType() << ' ';
1542  p.printRegion(getRegion());
1543  p.printOptionalAttrDict((*this)->getAttrs());
1544 }
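// Editor's note: a hypothetical textual form round-tripped by the parse/print
// methods above (a sketch; names and types are illustrative only).
// ```
// %res = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
// ^bb0(%current : f32):
//   %inc = arith.addf %cst, %current : f32
//   memref.atomic_yield %inc : f32
// }
// ```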
1545 
1546 //===----------------------------------------------------------------------===//
1547 // AtomicYieldOp
1548 //===----------------------------------------------------------------------===//
1549 
1550 LogicalResult AtomicYieldOp::verify() {
1551  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1552  Type resultType = getResult().getType();
1553  if (parentType != resultType)
1554  return emitOpError() << "type mismatch between yield op: " << resultType
1555  << " and its parent: " << parentType;
1556  return success();
1557 }
1558 
1559 //===----------------------------------------------------------------------===//
1560 // GlobalOp
1561 //===----------------------------------------------------------------------===//
1562 
1563 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1564  TypeAttr type,
1565  Attribute initialValue) {
1566  p << type;
1567  if (!op.isExternal()) {
1568  p << " = ";
1569  if (op.isUninitialized())
1570  p << "uninitialized";
1571  else
1572  p.printAttributeWithoutType(initialValue);
1573  }
1574 }
1575 
1576 static ParseResult
1577 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1578  Attribute &initialValue) {
1579  Type type;
1580  if (parser.parseType(type))
1581  return failure();
1582 
1583  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1584  if (!memrefType || !memrefType.hasStaticShape())
1585  return parser.emitError(parser.getNameLoc())
1586  << "type should be static shaped memref, but got " << type;
1587  typeAttr = TypeAttr::get(type);
1588 
1589  if (parser.parseOptionalEqual())
1590  return success();
1591 
1592  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1593  initialValue = UnitAttr::get(parser.getContext());
1594  return success();
1595  }
1596 
1597  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1598  if (parser.parseAttribute(initialValue, tensorType))
1599  return failure();
1600  if (!llvm::isa<ElementsAttr>(initialValue))
1601  return parser.emitError(parser.getNameLoc())
1602  << "initial value should be a unit or elements attribute";
1603  return success();
1604 }
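// Editor's note: illustrative forms accepted by the parser above (a sketch;
// symbol names and values are hypothetical).
// ```
// memref.global @a : memref<2xf32> = dense<[1.0, 2.0]>
// memref.global @b : memref<4xi32> = uninitialized
// memref.global @ext : memref<8xf32>   // external: no initial value
// ```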
1605 
1606 LogicalResult GlobalOp::verify() {
1607  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1608  if (!memrefType || !memrefType.hasStaticShape())
1609  return emitOpError("type should be static shaped memref, but got ")
1610  << getType();
1611 
1612  // Verify that the initial value, if present, is either a unit attribute or
1613  // an elements attribute.
1614  if (getInitialValue().has_value()) {
1615  Attribute initValue = getInitialValue().value();
1616  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1617  return emitOpError("initial value should be a unit or elements "
1618  "attribute, but got ")
1619  << initValue;
1620 
1621  // Check that the type of the initial value is compatible with the type of
1622  // the global variable.
1623  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1624  // Check the element types match.
1625  auto initElementType =
1626  cast<TensorType>(elementsAttr.getType()).getElementType();
1627  auto memrefElementType = memrefType.getElementType();
1628 
1629  if (initElementType != memrefElementType)
1630  return emitOpError("initial value element expected to be of type ")
1631  << memrefElementType << ", but was of type " << initElementType;
1632 
2633  // Check that the shapes match. Since memref globals can only produce
2634  // statically shaped memrefs and an elements literal type must have a
2635  // static shape, we can assume both types are shaped.
1636  auto initShape = elementsAttr.getShapedType().getShape();
1637  auto memrefShape = memrefType.getShape();
1638  if (initShape != memrefShape)
1639  return emitOpError("initial value shape expected to be ")
1640  << memrefShape << " but was " << initShape;
1641  }
1642  }
1643 
1644  // TODO: verify visibility for declarations.
1645  return success();
1646 }
1647 
1648 ElementsAttr GlobalOp::getConstantInitValue() {
1649  auto initVal = getInitialValue();
1650  if (getConstant() && initVal.has_value())
1651  return llvm::cast<ElementsAttr>(initVal.value());
1652  return {};
1653 }
1654 
1655 //===----------------------------------------------------------------------===//
1656 // GetGlobalOp
1657 //===----------------------------------------------------------------------===//
1658 
1659 LogicalResult
1660 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1661  // Verify that the result type is same as the type of the referenced
1662  // memref.global op.
1663  auto global =
1664  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1665  if (!global)
1666  return emitOpError("'")
1667  << getName() << "' does not reference a valid global memref";
1668 
1669  Type resultType = getResult().getType();
1670  if (global.getType() != resultType)
1671  return emitOpError("result type ")
1672  << resultType << " does not match type " << global.getType()
1673  << " of the global memref @" << getName();
1674  return success();
1675 }
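// Editor's note: a small sketch (hypothetical names) of the symbol check
// above: the result type must match the referenced global's type.
// ```
// memref.global @a : memref<2xf32> = dense<[1.0, 2.0]>
// ...
// %0 = memref.get_global @a : memref<2xf32>   // OK
// %1 = memref.get_global @a : memref<4xf32>   // verification error
// ```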
1676 
1677 //===----------------------------------------------------------------------===//
1678 // LoadOp
1679 //===----------------------------------------------------------------------===//
1680 
1681 LogicalResult LoadOp::verify() {
1682  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1683  return emitOpError("incorrect number of indices for load, expected ")
1684  << getMemRefType().getRank() << " but got " << getIndices().size();
1685  }
1686  return success();
1687 }
1688 
1689 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1690  /// load(memrefcast) -> load
1691  if (succeeded(foldMemRefCast(*this)))
1692  return getResult();
1693  return OpFoldResult();
1694 }
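// Editor's note: an illustrative instance (hypothetical IR) of the
// load(memrefcast) -> load fold performed above.
// ```
// %c = memref.cast %m : memref<4xf32> to memref<?xf32>
// %v = memref.load %c[%i] : memref<?xf32>
// // folds to:
// %v = memref.load %m[%i] : memref<4xf32>
// ```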
1695 
1696 FailureOr<std::optional<SmallVector<Value>>>
1697 LoadOp::bubbleDownCasts(OpBuilder &builder) {
1698  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
1699  getResult());
1700 }
1701 
1702 //===----------------------------------------------------------------------===//
1703 // MemorySpaceCastOp
1704 //===----------------------------------------------------------------------===//
1705 
1706 void MemorySpaceCastOp::getAsmResultNames(
1707  function_ref<void(Value, StringRef)> setNameFn) {
1708  setNameFn(getResult(), "memspacecast");
1709 }
1710 
1711 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1712  if (inputs.size() != 1 || outputs.size() != 1)
1713  return false;
1714  Type a = inputs.front(), b = outputs.front();
1715  auto aT = llvm::dyn_cast<MemRefType>(a);
1716  auto bT = llvm::dyn_cast<MemRefType>(b);
1717 
1718  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1719  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1720 
1721  if (aT && bT) {
1722  if (aT.getElementType() != bT.getElementType())
1723  return false;
1724  if (aT.getLayout() != bT.getLayout())
1725  return false;
1726  if (aT.getShape() != bT.getShape())
1727  return false;
1728  return true;
1729  }
1730  if (uaT && ubT) {
1731  return uaT.getElementType() == ubT.getElementType();
1732  }
1733  return false;
1734 }
1735 
1736 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1737  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1738  // t2)
1739  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1740  getSourceMutable().assign(parentCast.getSource());
1741  return getResult();
1742  }
1743  return Value{};
1744 }
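// Editor's note: a sketch (hypothetical memory spaces) of the cast-of-cast
// fold above.
// ```
// %1 = memref.memory_space_cast %0 : memref<8xf32, 1> to memref<8xf32, 2>
// %2 = memref.memory_space_cast %1 : memref<8xf32, 2> to memref<8xf32>
// // folds to:
// %2 = memref.memory_space_cast %0 : memref<8xf32, 1> to memref<8xf32>
// ```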
1745 
1746 TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
1747  return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
1748 }
1749 
1750 TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
1751  return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
1752 }
1753 
1754 bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1755  PtrLikeTypeInterface src) {
1756  return isa<BaseMemRefType>(tgt) &&
1757  tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
1758 }
1759 
1760 MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
1761  OpBuilder &b, PtrLikeTypeInterface tgt,
1762  TypedValue<PtrLikeTypeInterface> src) {
1763  assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
1764  return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
1765 }
1766 
1767 /// The only cast we recognize as promotable is to the generic space.
1768 bool MemorySpaceCastOp::isSourcePromotable() {
1769  return getDest().getType().getMemorySpace() == nullptr;
1770 }
1771 
1772 //===----------------------------------------------------------------------===//
1773 // PrefetchOp
1774 //===----------------------------------------------------------------------===//
1775 
1777  p << " " << getMemref() << '[';
1778  p.printOperands(getIndices());
1779  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1780  p << ", locality<" << getLocalityHint();
1781  p << ">, " << (getIsDataCache() ? "data" : "instr");
1782  p.printOptionalAttrDict(
1783  (*this)->getAttrs(),
1784  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1785  p << " : " << getMemRefType();
1786 }
1787 
1788 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1789  OpAsmParser::UnresolvedOperand memrefInfo;
1790  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1791  IntegerAttr localityHint;
1792  MemRefType type;
1793  StringRef readOrWrite, cacheType;
1794 
1795  auto indexTy = parser.getBuilder().getIndexType();
1796  auto i32Type = parser.getBuilder().getIntegerType(32);
1797  if (parser.parseOperand(memrefInfo) ||
1798  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1799  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1800  parser.parseComma() || parser.parseKeyword("locality") ||
1801  parser.parseLess() ||
1802  parser.parseAttribute(localityHint, i32Type, "localityHint",
1803  result.attributes) ||
1804  parser.parseGreater() || parser.parseComma() ||
1805  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1806  parser.resolveOperand(memrefInfo, type, result.operands) ||
1807  parser.resolveOperands(indexInfo, indexTy, result.operands))
1808  return failure();
1809 
1810  if (readOrWrite != "read" && readOrWrite != "write")
1811  return parser.emitError(parser.getNameLoc(),
1812  "rw specifier has to be 'read' or 'write'");
1813  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1814  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1815 
1816  if (cacheType != "data" && cacheType != "instr")
1817  return parser.emitError(parser.getNameLoc(),
1818  "cache type has to be 'data' or 'instr'");
1819 
1820  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1821  parser.getBuilder().getBoolAttr(cacheType == "data"));
1822 
1823  return success();
1824 }
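// Editor's note: an illustrative form handled by the printer/parser above
// (operand names are hypothetical).
// ```
// memref.prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xf32>
// ```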
1825 
1826 LogicalResult PrefetchOp::verify() {
1827  if (getNumOperands() != 1 + getMemRefType().getRank())
1828  return emitOpError("too few indices");
1829 
1830  return success();
1831 }
1832 
1833 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1834  SmallVectorImpl<OpFoldResult> &results) {
1835  // prefetch(memrefcast) -> prefetch
1836  return foldMemRefCast(*this);
1837 }
1838 
1839 //===----------------------------------------------------------------------===//
1840 // RankOp
1841 //===----------------------------------------------------------------------===//
1842 
1843 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1844  // Constant fold rank when the rank of the operand is known.
1845  auto type = getOperand().getType();
1846  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1847  if (shapedType && shapedType.hasRank())
1848  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1849  return IntegerAttr();
1850 }
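// Editor's note: a sketch of the constant fold above for a ranked operand
// (hypothetical IR).
// ```
// %r = memref.rank %m : memref<4x?xf32>
// // folds to:
// %r = arith.constant 2 : index
// ```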
1851 
1852 //===----------------------------------------------------------------------===//
1853 // ReinterpretCastOp
1854 //===----------------------------------------------------------------------===//
1855 
1856 void ReinterpretCastOp::getAsmResultNames(
1857  function_ref<void(Value, StringRef)> setNameFn) {
1858  setNameFn(getResult(), "reinterpret_cast");
1859 }
1860 
1861 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1862 /// `staticSizes` and `staticStrides` are automatically filled with
1863 /// source-memref-rank sentinel values that encode dynamic entries.
1864 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1865  MemRefType resultType, Value source,
1866  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1867  ArrayRef<OpFoldResult> strides,
1868  ArrayRef<NamedAttribute> attrs) {
1869  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1870  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1871  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1872  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1873  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1874  result.addAttributes(attrs);
1875  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1876  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1877  b.getDenseI64ArrayAttr(staticSizes),
1878  b.getDenseI64ArrayAttr(staticStrides));
1879 }
1880 
1881 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1882  Value source, OpFoldResult offset,
1883  ArrayRef<OpFoldResult> sizes,
1884  ArrayRef<OpFoldResult> strides,
1885  ArrayRef<NamedAttribute> attrs) {
1886  auto sourceType = cast<BaseMemRefType>(source.getType());
1887  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1888  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1889  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1890  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1891  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1892  auto stridedLayout = StridedLayoutAttr::get(
1893  b.getContext(), staticOffsets.front(), staticStrides);
1894  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1895  stridedLayout, sourceType.getMemorySpace());
1896  build(b, result, resultType, source, offset, sizes, strides, attrs);
1897 }
1898 
1899 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1900  MemRefType resultType, Value source,
1901  int64_t offset, ArrayRef<int64_t> sizes,
1902  ArrayRef<int64_t> strides,
1903  ArrayRef<NamedAttribute> attrs) {
1904  SmallVector<OpFoldResult> sizeValues =
1905  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1906  return b.getI64IntegerAttr(v);
1907  }));
1908  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1909  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1910  return b.getI64IntegerAttr(v);
1911  }));
1912  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1913  strideValues, attrs);
1914 }
1915 
1916 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1917  MemRefType resultType, Value source, Value offset,
1918  ValueRange sizes, ValueRange strides,
1919  ArrayRef<NamedAttribute> attrs) {
1920  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1921  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1922  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1923  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1924  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1925 }
1926 
1927 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1928 // completed automatically, like we have for subview and extract_slice.
1929 LogicalResult ReinterpretCastOp::verify() {
1930  // The source and result memrefs should be in the same memory space.
1931  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1932  auto resultType = llvm::cast<MemRefType>(getType());
1933  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1934  return emitError("different memory spaces specified for source type ")
1935  << srcType << " and result memref type " << resultType;
1936  if (srcType.getElementType() != resultType.getElementType())
1937  return emitError("different element types specified for source type ")
1938  << srcType << " and result memref type " << resultType;
1939 
1940  // Match sizes in result memref type and in static_sizes attribute.
1941  for (auto [idx, resultSize, expectedSize] :
1942  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1943  if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
1944  return emitError("expected result type with size = ")
1945  << (ShapedType::isDynamic(expectedSize)
1946  ? std::string("dynamic")
1947  : std::to_string(expectedSize))
1948  << " instead of " << resultSize << " in dim = " << idx;
1949  }
1950 
1951  // Match offset and strides in static_offsets and static_strides attributes. If
1952  // result memref type has no affine map specified, this will assume an
1953  // identity layout.
1954  int64_t resultOffset;
1955  SmallVector<int64_t, 4> resultStrides;
1956  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1957  return emitError("expected result type to have strided layout but found ")
1958  << resultType;
1959 
1960  // Match offset in result memref type and in static_offsets attribute.
1961  int64_t expectedOffset = getStaticOffsets().front();
1962  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
1963  return emitError("expected result type with offset = ")
1964  << (ShapedType::isDynamic(expectedOffset)
1965  ? std::string("dynamic")
1966  : std::to_string(expectedOffset))
1967  << " instead of " << resultOffset;
1968 
1969  // Match strides in result memref type and in static_strides attribute.
1970  for (auto [idx, resultStride, expectedStride] :
1971  llvm::enumerate(resultStrides, getStaticStrides())) {
1972  if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
1973  return emitError("expected result type with stride = ")
1974  << (ShapedType::isDynamic(expectedStride)
1975  ? std::string("dynamic")
1976  : std::to_string(expectedStride))
1977  << " instead of " << resultStride << " in dim = " << idx;
1978  }
1979 
1980  return success();
1981 }
1982 
1983 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1984  Value src = getSource();
1985  auto getPrevSrc = [&]() -> Value {
1986  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1987  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1988  return prev.getSource();
1989 
1990  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1991  if (auto prev = src.getDefiningOp<CastOp>())
1992  return prev.getSource();
1993 
1994  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1995  // are 0.
1996  if (auto prev = src.getDefiningOp<SubViewOp>())
1997  if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
1998  return prev.getSource();
1999 
2000  return nullptr;
2001  };
2002 
2003  if (auto prevSrc = getPrevSrc()) {
2004  getSourceMutable().assign(prevSrc);
2005  return getResult();
2006  }
2007 
2008  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
2009  if (ShapedType::isStaticShape(getType().getShape()) &&
2010  src.getType() == getType() && getStaticOffsets().front() == 0) {
2011  return src;
2012  }
2013 
2014  return nullptr;
2015 }
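// Editor's note: a sketch (hypothetical IR) of the "no-op reinterpret_cast"
// fold above: same type, static shape, and zero static offset fold to the
// source value.
// ```
// %r = memref.reinterpret_cast %m to offset: [0], sizes: [4, 4],
//     strides: [4, 1] : memref<4x4xf32> to memref<4x4xf32>
// // folds to %m
// ```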
2016 
2017 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
2018  SmallVector<OpFoldResult> values = getMixedSizes();
2019  constifyIndexValues(values, getType().getShape());
2020  return values;
2021 }
2022 
2023 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
2024  SmallVector<OpFoldResult> values = getMixedStrides();
2025  SmallVector<int64_t> staticValues;
2026  int64_t unused;
2027  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
2028  (void)status;
2029  assert(succeeded(status) && "could not get strides from type");
2030  constifyIndexValues(values, staticValues);
2031  return values;
2032 }
2033 
2034 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
2035  SmallVector<OpFoldResult> values = getMixedOffsets();
2036  assert(values.size() == 1 &&
2037  "reinterpret_cast must have one and only one offset");
2038  SmallVector<int64_t> staticValues, unused;
2039  int64_t offset;
2040  LogicalResult status = getType().getStridesAndOffset(unused, offset);
2041  (void)status;
2042  assert(succeeded(status) && "could not get offset from type");
2043  staticValues.push_back(offset);
2044  constifyIndexValues(values, staticValues);
2045  return values[0];
2046 }
2047 
2048 namespace {
2049 /// Replace the sequence:
2050 /// ```
2051 /// base, offset, sizes, strides = extract_strided_metadata src
2052 /// dst = reinterpret_cast base to offset, sizes, strides
2053 /// ```
2054 /// With
2055 ///
2056 /// ```
2057 /// dst = memref.cast src
2058 /// ```
2059 ///
2060 /// Note: The cast operation is only inserted when the type of dst and src
2061 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
2062 ///
2063 /// This pattern also matches when the offset, sizes, and strides don't come
2064 /// directly from the `extract_strided_metadata`'s results but it can be
2065 /// statically proven that they would hold the same values.
2066 ///
2067 /// For instance, the following sequence would be replaced:
2068 /// ```
2069 /// base, offset, sizes, strides =
2070 /// extract_strided_metadata memref : memref<3x4xty>
2071 /// dst = reinterpret_cast base to 0, [3, 4], strides
2072 /// ```
2073 /// Because we know (thanks to the type of the input memref) that variable
2074 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
2075 ///
2076 /// Similarly, the following sequence would be replaced:
2077 /// ```
2078 /// c0 = arith.constant 0
2079 /// c4 = arith.constant 4
2080 /// base, offset, sizes, strides =
2081 /// extract_strided_metadata memref : memref<3x4xty>
2082 /// dst = reinterpret_cast base to c0, [3, c4], strides
2083 /// ```
2084 /// Because we know that `offset` and `c0` will hold 0
2085 /// and `c4` will hold 4.
2086 ///
2087 /// If the pattern above does not match, the input of the
2088 /// extract_strided_metadata is always folded into the input of the
2089 /// reinterpret_cast operator. This allows for dead code elimination to get rid
2090 /// of the extract_strided_metadata in some cases.
2091 struct ReinterpretCastOpExtractStridedMetadataFolder
2092  : public OpRewritePattern<ReinterpretCastOp> {
2093 public:
2094  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2095 
2096  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2097  PatternRewriter &rewriter) const override {
2098  auto extractStridedMetadata =
2099  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2100  if (!extractStridedMetadata)
2101  return failure();
2102 
2103  // Check if the reinterpret cast reconstructs a memref with the exact same
2104  // properties as the extract strided metadata.
2105  auto isReinterpretCastNoop = [&]() -> bool {
2106  // First, check that the strides are the same.
2107  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
2108  op.getConstifiedMixedStrides()))
2109  return false;
2110 
2111  // Second, check the sizes.
2112  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
2113  op.getConstifiedMixedSizes()))
2114  return false;
2115 
2116  // Finally, check the offset.
2117  assert(op.getMixedOffsets().size() == 1 &&
2118  "reinterpret_cast with more than one offset should have been "
2119  "rejected by the verifier");
2120  return extractStridedMetadata.getConstifiedMixedOffset() ==
2121  op.getConstifiedMixedOffset();
2122  };
2123 
2124  if (!isReinterpretCastNoop()) {
2125  // If the extract_strided_metadata / reinterpret_cast pair can't be
2126  // completely folded, we can still fold the input of the
2127  // extract_strided_metadata into the input of the reinterpret_cast.
2128  // For some cases (e.g., static dimensions), the extract_strided_metadata
2129  // is then eliminated by dead code elimination.
2130  //
2131  // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2132  //
2133  // We can always fold the input of a extract_strided_metadata operator
2134  // to the input of a reinterpret_cast operator, because they point to
2135  // the same memory. Note that the reinterpret_cast does not use the
2136  // layout of its input memref, only its base memory pointer which is
2137  // the same as the base pointer returned by the extract_strided_metadata
2138  // operator and the base pointer of the extract_strided_metadata memref
2139  // input.
2140  rewriter.modifyOpInPlace(op, [&]() {
2141  op.getSourceMutable().assign(extractStridedMetadata.getSource());
2142  });
2143  return success();
2144  }
2145 
2146  // At this point, we know that the back and forth between extract strided
2147  // metadata and reinterpret cast is a noop. However, the final type of the
2148  // reinterpret cast may not be exactly the same as the original memref.
2149  // E.g., it could be changing a dimension from static to dynamic. Check that
2150  // here and add a cast if necessary.
2151  Type srcTy = extractStridedMetadata.getSource().getType();
2152  if (srcTy == op.getResult().getType())
2153  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2154  else
2155  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2156  extractStridedMetadata.getSource());
2157 
2158  return success();
2159  }
2160 };
2161 } // namespace
2162 
2163 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2164  MLIRContext *context) {
2165  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2166 }
2167 
2168 FailureOr<std::optional<SmallVector<Value>>>
2169 ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
2170  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2171 }
2172 
2173 //===----------------------------------------------------------------------===//
2174 // Reassociative reshape ops
2175 //===----------------------------------------------------------------------===//
2176 
2177 void CollapseShapeOp::getAsmResultNames(
2178  function_ref<void(Value, StringRef)> setNameFn) {
2179  setNameFn(getResult(), "collapse_shape");
2180 }
2181 
2182 void ExpandShapeOp::getAsmResultNames(
2183  function_ref<void(Value, StringRef)> setNameFn) {
2184  setNameFn(getResult(), "expand_shape");
2185 }
2186 
2187 LogicalResult ExpandShapeOp::reifyResultShapes(
2188  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2189  reifiedResultShapes = {
2190  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2191  return success();
2192 }
2193 
2194 /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
2195 /// result and operand. Layout maps are verified separately.
2196 ///
2197 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2198 /// allowed in a reassociation group.
2199 static LogicalResult
2200 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2201  ArrayRef<int64_t> expandedShape,
2202  ArrayRef<ReassociationIndices> reassociation,
2203  bool allowMultipleDynamicDimsPerGroup) {
2204  // There must be one reassociation group per collapsed dimension.
2205  if (collapsedShape.size() != reassociation.size())
2206  return op->emitOpError("invalid number of reassociation groups: found ")
2207  << reassociation.size() << ", expected " << collapsedShape.size();
2208 
2209  // The next expected expanded dimension index (while iterating over
2210  // reassociation indices).
2211  int64_t nextDim = 0;
2212  for (const auto &it : llvm::enumerate(reassociation)) {
2213  ReassociationIndices group = it.value();
2214  int64_t collapsedDim = it.index();
2215 
2216  bool foundDynamic = false;
2217  for (int64_t expandedDim : group) {
2218  if (expandedDim != nextDim++)
2219  return op->emitOpError("reassociation indices must be contiguous");
2220 
2221  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2222  return op->emitOpError("reassociation index ")
2223  << expandedDim << " is out of bounds";
2224 
2225  // Check if there are multiple dynamic dims in a reassociation group.
2226  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2227  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2228  return op->emitOpError(
2229  "at most one dimension in a reassociation group may be dynamic");
2230  foundDynamic = true;
2231  }
2232  }
2233 
2234  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2235  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2236  return op->emitOpError("collapsed dim (")
2237  << collapsedDim
2238  << ") must be dynamic if and only if reassociation group is "
2239  "dynamic";
2240 
2241  // If all dims in the reassociation group are static, the size of the
2242  // collapsed dim can be verified.
2243  if (!foundDynamic) {
2244  int64_t groupSize = 1;
2245  for (int64_t expandedDim : group)
2246  groupSize *= expandedShape[expandedDim];
2247  if (groupSize != collapsedShape[collapsedDim])
2248  return op->emitOpError("collapsed dim size (")
2249  << collapsedShape[collapsedDim]
2250  << ") must equal reassociation group size (" << groupSize << ")";
2251  }
2252  }
2253 
2254  if (collapsedShape.empty()) {
2255  // Rank 0: All expanded dimensions must be 1.
2256  for (int64_t d : expandedShape)
2257  if (d != 1)
2258  return op->emitOpError(
2259  "rank 0 memrefs can only be extended/collapsed with/from ones");
2260  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2261  // Rank >= 1: Number of dimensions among all reassociation groups must match
2262  // the result memref rank.
2263  return op->emitOpError("expanded rank (")
2264  << expandedShape.size()
2265  << ") inconsistent with number of reassociation indices (" << nextDim
2266  << ")";
2267  }
2268 
2269  return success();
2270 }
2271 
2272 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2273  return getSymbolLessAffineMaps(getReassociationExprs());
2274 }
2275 
2276 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2277  return convertReassociationIndicesToExprs(getContext(),
2278  getReassociationIndices());
2279 }
2280 
2281 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2282  return getSymbolLessAffineMaps(getReassociationExprs());
2283 }
2284 
2285 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2286  return convertReassociationIndicesToExprs(getContext(),
2287  getReassociationIndices());
2288 }
2289 
2290 /// Compute the layout map after expanding a given source MemRef type with the
2291 /// specified reassociation indices.
2292 static FailureOr<StridedLayoutAttr>
2293 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2294  ArrayRef<ReassociationIndices> reassociation) {
2295  int64_t srcOffset;
2296  SmallVector<int64_t> srcStrides;
2297  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2298  return failure();
2299  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2300 
2301  // 1-1 mapping between srcStrides and reassociation packs.
2302  // Each srcStride starts with the given value and gets expanded according to
2303  // the proper entries in resultShape.
2304  // Example:
2305  // srcStrides = [10000, 1 , 100 ],
2306  // reassociations = [ [0], [1], [2, 3, 4]],
2307  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2308  // -> For the purpose of stride calculation, the useful sizes are:
2309  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2310  // resultStrides = [10000, 1, 600, 200, 100]
2311  // Note that a stride does not get expanded along the first entry of each
2312  // shape pack.
2313  SmallVector<int64_t> reverseResultStrides;
2314  reverseResultStrides.reserve(resultShape.size());
2315  unsigned shapeIndex = resultShape.size() - 1;
2316  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2317  ReassociationIndices reassoc = std::get<0>(it);
2318  int64_t currentStrideToExpand = std::get<1>(it);
2319  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2320  reverseResultStrides.push_back(currentStrideToExpand);
2321  currentStrideToExpand =
2322  (SaturatedInteger::wrap(currentStrideToExpand) *
2323  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2324  .asInteger();
2325  }
2326  }
2327  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2328  resultStrides.resize(resultShape.size(), 1);
2329  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2330 }
2331 
2332 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2333  MemRefType srcType, ArrayRef<int64_t> resultShape,
2334  ArrayRef<ReassociationIndices> reassociation) {
2335  if (srcType.getLayout().isIdentity()) {
2336  // If the source is contiguous (i.e., no layout map specified), so is the
2337  // result.
2338  MemRefLayoutAttrInterface layout;
2339  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2340  srcType.getMemorySpace());
2341  }
2342 
2343  // Source may not be contiguous. Compute the layout map.
2344  FailureOr<StridedLayoutAttr> computedLayout =
2345  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2346  if (failed(computedLayout))
2347  return failure();
2348  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2349  srcType.getMemorySpace());
2350 }
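// Editor's note: a worked sketch (hypothetical types) of the expanded layout
// computation above for a non-identity source layout.
// ```
// %e = memref.expand_shape %m [[0, 1], [2]] output_shape [3, 4, 5]
//     : memref<12x5xf32, strided<[6, 1], offset: 2>>
//       into memref<3x4x5xf32, strided<[24, 6, 1], offset: 2>>
// ```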
2351 
2352 FailureOr<SmallVector<OpFoldResult>>
2353 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2354  MemRefType expandedType,
2355  ArrayRef<ReassociationIndices> reassociation,
2356  ArrayRef<OpFoldResult> inputShape) {
2357  std::optional<SmallVector<OpFoldResult>> outputShape =
2358  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2359  inputShape);
2360  if (!outputShape)
2361  return failure();
2362  return *outputShape;
2363 }
2364 
2365 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2366  Type resultType, Value src,
2367  ArrayRef<ReassociationIndices> reassociation,
2368  ArrayRef<OpFoldResult> outputShape) {
2369  auto [staticOutputShape, dynamicOutputShape] =
2370  decomposeMixedValues(outputShape);
2371  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2372  getReassociationIndicesAttribute(builder, reassociation),
2373  dynamicOutputShape, staticOutputShape);
2374 }
2375 
2376 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2377  Type resultType, Value src,
2378  ArrayRef<ReassociationIndices> reassociation) {
2379  SmallVector<OpFoldResult> inputShape =
2380  getMixedSizes(builder, result.location, src);
2381  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2382  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2383  builder, result.location, memrefResultTy, reassociation, inputShape);
2384  // Failure of this assertion usually indicates presence of multiple
2385  // dynamic dimensions in the same reassociation group.
2386  assert(succeeded(outputShape) && "unable to infer output shape");
2387  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2388 }
2389 
2390 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2391  ArrayRef<int64_t> resultShape, Value src,
2392  ArrayRef<ReassociationIndices> reassociation) {
2393  // Only ranked memref source values are supported.
2394  auto srcType = llvm::cast<MemRefType>(src.getType());
2395  FailureOr<MemRefType> resultType =
2396  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2397  // Failure of this assertion usually indicates a problem with the source
2398  // type, e.g., could not get strides/offset.
2399  assert(succeeded(resultType) && "could not compute layout");
2400  build(builder, result, *resultType, src, reassociation);
2401 }
2402 
2403 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2404  ArrayRef<int64_t> resultShape, Value src,
2405  ArrayRef<ReassociationIndices> reassociation,
2406  ArrayRef<OpFoldResult> outputShape) {
2407  // Only ranked memref source values are supported.
2408  auto srcType = llvm::cast<MemRefType>(src.getType());
2409  FailureOr<MemRefType> resultType =
2410  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2411  // Failure of this assertion usually indicates a problem with the source
2412  // type, e.g., could not get strides/offset.
2413  assert(succeeded(resultType) && "could not compute layout");
2414  build(builder, result, *resultType, src, reassociation, outputShape);
2415 }
2416 
2417 LogicalResult ExpandShapeOp::verify() {
2418  MemRefType srcType = getSrcType();
2419  MemRefType resultType = getResultType();
2420 
2421  if (srcType.getRank() > resultType.getRank()) {
2422  auto r0 = srcType.getRank();
2423  auto r1 = resultType.getRank();
2424  return emitOpError("has source rank ")
2425  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2426  << r0 << " > " << r1 << ").";
2427  }
2428 
2429  // Verify result shape.
2430  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2431  resultType.getShape(),
2432  getReassociationIndices(),
2433  /*allowMultipleDynamicDimsPerGroup=*/true)))
2434  return failure();
2435 
2436  // Compute expected result type (including layout map).
2437  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2438  srcType, resultType.getShape(), getReassociationIndices());
2439  if (failed(expectedResultType))
2440  return emitOpError("invalid source layout map");
2441 
2442  // Check actual result type.
2443  if (*expectedResultType != resultType)
2444  return emitOpError("expected expanded type to be ")
2445  << *expectedResultType << " but found " << resultType;
2446 
2447  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2448  return emitOpError("expected number of static shape bounds to be equal to "
2449  "the output rank (")
2450  << resultType.getRank() << ") but found "
2451  << getStaticOutputShape().size() << " inputs instead";
2452 
2453  if ((int64_t)getOutputShape().size() !=
2454  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2455  return emitOpError("mismatch in dynamic dims in output_shape and "
2456  "static_output_shape: static_output_shape has ")
2457  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2458  << " dynamic dims while output_shape has " << getOutputShape().size()
2459  << " values";
2460 
2461  // Verify if provided output shapes are in agreement with output type.
2462  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2463  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2464  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2465  if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
2466  return emitOpError("invalid output shape provided at pos ") << pos;
2467  }
2468  }
2469 
2470  return success();
2471 }
2472 
2473 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2474  MLIRContext *context) {
2475  results.add<
2476  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2477  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2478 }
2479 
2480 FailureOr<std::optional<SmallVector<Value>>>
2481 ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
2482  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2483 }
2484 
2485 /// Compute the layout map after collapsing a given source MemRef type with the
2486 /// specified reassociation indices.
2487 ///
2488 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2489 /// not possible to check this by inspecting a MemRefType in the general case.
2490 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2491 /// be valid (and thus accepted by this function) unless `strict = true`.
2492 static FailureOr<StridedLayoutAttr>
2493 computeCollapsedLayoutMap(MemRefType srcType,
2494  ArrayRef<ReassociationIndices> reassociation,
2495  bool strict = false) {
2496  int64_t srcOffset;
2497  SmallVector<int64_t> srcStrides;
2498  auto srcShape = srcType.getShape();
2499  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2500  return failure();
2501 
2502  // The result stride of a reassociation group is the stride of the last entry
2503  // of the reassociation. (TODO: Should be the minimum stride in the
2504  // reassociation because strides are not necessarily sorted. E.g., when using
2505  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2506  // strides are meaningless and could have any arbitrary value.
2507  SmallVector<int64_t> resultStrides;
2508  resultStrides.reserve(reassociation.size());
2509  for (const ReassociationIndices &reassoc : reassociation) {
2510  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2511  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2512  ref = ref.drop_back();
2513  if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2514  resultStrides.push_back(srcStrides[ref.back()]);
2515  } else {
2516  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2517  // the corresponding stride may have to be skipped. (See above comment.)
2518  // Therefore, the result stride cannot be statically determined and must
2519  // be dynamic.
2520  resultStrides.push_back(ShapedType::kDynamic);
2521  }
2522  }
2523 
2524  // Validate that each reassociation group is contiguous.
2525  unsigned resultStrideIndex = resultStrides.size() - 1;
2526  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2527  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2528  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2529  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2530  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2531 
2532  // Both source and result stride must have the same static value. In that
2533  // case, we can be sure that the dimensions are collapsible (because they
2534  // are contiguous).
2535  // If `strict = false` (default during op verification), we accept cases
2536  // where one or both strides are dynamic. This is best effort: We reject
2537  // ops where obviously non-contiguous dims are collapsed, but accept ops
2538  // where we cannot be sure statically. Such ops may fail at runtime. See
2539  // the op documentation for details.
2540  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2541  if (strict && (stride.saturated || srcStride.saturated))
2542  return failure();
2543 
2544  // Dimensions of size 1 should be skipped, because their strides are
2545  // meaningless and could have any arbitrary value.
2546  if (srcShape[idx - 1] == 1)
2547  continue;
2548 
2549  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2550  return failure();
2551  }
2552  }
2553  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2554 }
2555 
2556 bool CollapseShapeOp::isGuaranteedCollapsible(
2557  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2558  // MemRefs with identity layout are always collapsible.
2559  if (srcType.getLayout().isIdentity())
2560  return true;
2561 
2562  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2563  /*strict=*/true));
2564 }
2565 
2566 MemRefType CollapseShapeOp::computeCollapsedType(
2567  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2568  SmallVector<int64_t> resultShape;
2569  resultShape.reserve(reassociation.size());
2570  for (const ReassociationIndices &group : reassociation) {
2571  auto groupSize = SaturatedInteger::wrap(1);
2572  for (int64_t srcDim : group)
2573  groupSize =
2574  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2575  resultShape.push_back(groupSize.asInteger());
2576  }
2577 
2578  if (srcType.getLayout().isIdentity()) {
2579  // If the source is contiguous (i.e., no layout map specified), so is the
2580  // result.
2581  MemRefLayoutAttrInterface layout;
2582  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2583  srcType.getMemorySpace());
2584  }
2585 
2586  // Source may not be fully contiguous. Compute the layout map.
2587  // Note: Dimensions that are collapsed into a single dim are assumed to be
2588  // contiguous.
2589  FailureOr<StridedLayoutAttr> computedLayout =
2590  computeCollapsedLayoutMap(srcType, reassociation);
2591  assert(succeeded(computedLayout) &&
2592  "invalid source layout map or collapsing non-contiguous dims");
2593  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2594  srcType.getMemorySpace());
2595 }
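// Editor's note: an illustrative collapse (hypothetical IR) of contiguous
// dimensions, matching the collapsed-type computation above.
// ```
// %c = memref.collapse_shape %m [[0, 1], [2]]
//     : memref<3x4x5xf32> into memref<12x5xf32>
// ```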
2596 
2597 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2598  ArrayRef<ReassociationIndices> reassociation,
2599  ArrayRef<NamedAttribute> attrs) {
2600  auto srcType = llvm::cast<MemRefType>(src.getType());
2601  MemRefType resultType =
2602  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2603  result.addAttribute(getReassociationAttrStrName(),
2604  getReassociationIndicesAttribute(b, reassociation));
2605  build(b, result, resultType, src, attrs);
2606 }
2607 
2608 LogicalResult CollapseShapeOp::verify() {
2609  MemRefType srcType = getSrcType();
2610  MemRefType resultType = getResultType();
2611 
2612  if (srcType.getRank() < resultType.getRank()) {
2613  auto r0 = srcType.getRank();
2614  auto r1 = resultType.getRank();
2615  return emitOpError("has source rank ")
2616  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2617  << r0 << " < " << r1 << ").";
2618  }
2619 
2620  // Verify result shape.
2621  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2622  srcType.getShape(), getReassociationIndices(),
2623  /*allowMultipleDynamicDimsPerGroup=*/true)))
2624  return failure();
2625 
2626  // Compute expected result type (including layout map).
2627  MemRefType expectedResultType;
2628  if (srcType.getLayout().isIdentity()) {
2629  // If the source is contiguous (i.e., no layout map specified), so is the
2630  // result.
2631  MemRefLayoutAttrInterface layout;
2632  expectedResultType =
2633  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2634  srcType.getMemorySpace());
2635  } else {
2636  // Source may not be fully contiguous. Compute the layout map.
2637  // Note: Dimensions that are collapsed into a single dim are assumed to be
2638  // contiguous.
2639  FailureOr<StridedLayoutAttr> computedLayout =
2640  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2641  if (failed(computedLayout))
2642  return emitOpError(
2643  "invalid source layout map or collapsing non-contiguous dims");
2644  expectedResultType =
2645  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2646  *computedLayout, srcType.getMemorySpace());
2647  }
2648 
2649  if (expectedResultType != resultType)
2650  return emitOpError("expected collapsed type to be ")
2651  << expectedResultType << " but found " << resultType;
2652 
2653  return success();
2654 }
2655 
2656 struct CollapseShapeOpMemRefCastFolder
2657  : public OpRewritePattern<CollapseShapeOp> {
2658 public:
2659  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2660 
2661  LogicalResult matchAndRewrite(CollapseShapeOp op,
2662  PatternRewriter &rewriter) const override {
2663  auto cast = op.getOperand().getDefiningOp<CastOp>();
2664  if (!cast)
2665  return failure();
2666 
2667  if (!CastOp::canFoldIntoConsumerOp(cast))
2668  return failure();
2669 
2670  Type newResultType = CollapseShapeOp::computeCollapsedType(
2671  llvm::cast<MemRefType>(cast.getOperand().getType()),
2672  op.getReassociationIndices());
2673 
2674  if (newResultType == op.getResultType()) {
2675  rewriter.modifyOpInPlace(
2676  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2677  } else {
2678  Value newOp =
2679  CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
2680  op.getReassociationIndices());
2681  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2682  }
2683  return success();
2684  }
2685 };
2686 
2687 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2688  MLIRContext *context) {
2689  results.add<
2690  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2691  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2692  memref::DimOp, MemRefType>,
2693  CollapseShapeOpMemRefCastFolder>(context);
2694 }
2695 
2696 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2697  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2698  adaptor.getOperands());
2699 }
2700 
2701 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2702  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2703  adaptor.getOperands());
2704 }
2705 
2706 FailureOr<std::optional<SmallVector<Value>>>
2707 CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
2708  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2709 }
2710 
2711 //===----------------------------------------------------------------------===//
2712 // ReshapeOp
2713 //===----------------------------------------------------------------------===//
2714 
2715 void ReshapeOp::getAsmResultNames(
2716  function_ref<void(Value, StringRef)> setNameFn) {
2717  setNameFn(getResult(), "reshape");
2718 }
2719 
2720 LogicalResult ReshapeOp::verify() {
2721  Type operandType = getSource().getType();
2722  Type resultType = getResult().getType();
2723 
2724  Type operandElementType =
2725  llvm::cast<ShapedType>(operandType).getElementType();
2726  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2727  if (operandElementType != resultElementType)
2728  return emitOpError("element types of source and destination memref "
2729  "types should be the same");
2730 
2731  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2732  if (!operandMemRefType.getLayout().isIdentity())
2733  return emitOpError("source memref type should have identity affine map");
2734 
2735  int64_t shapeSize =
2736  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2737  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2738  if (resultMemRefType) {
2739  if (!resultMemRefType.getLayout().isIdentity())
2740  return emitOpError("result memref type should have identity affine map");
2741  if (shapeSize == ShapedType::kDynamic)
2742  return emitOpError("cannot use shape operand with dynamic length to "
2743  "reshape to statically-ranked memref type");
2744  if (shapeSize != resultMemRefType.getRank())
2745  return emitOpError(
2746  "length of shape operand differs from the result's memref rank");
2747  }
2748  return success();
2749 }
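// Editor's note: an illustrative reshape (hypothetical IR) satisfying the
// verifier above: identity layouts, matching element types, and a statically
// sized shape operand whose length equals the result rank.
// ```
// %r = memref.reshape %src(%shape)
//     : (memref<4x4xf32>, memref<2xindex>) -> memref<?x?xf32>
// ```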
2750 
2751 FailureOr<std::optional<SmallVector<Value>>>
2752 ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
2753  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2754 }
2755 
2756 //===----------------------------------------------------------------------===//
2757 // StoreOp
2758 //===----------------------------------------------------------------------===//
2759 
2760 LogicalResult StoreOp::verify() {
2761  if (getNumOperands() != 2 + getMemRefType().getRank())
2762  return emitOpError("store index operand count not equal to memref rank");
2763 
2764  return success();
2765 }
2766 
2767 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2768  SmallVectorImpl<OpFoldResult> &results) {
2769  /// store(memrefcast) -> store
2770  return foldMemRefCast(*this, getValueToStore());
2771 }
2772 
2773 FailureOr<std::optional<SmallVector<Value>>>
2774 StoreOp::bubbleDownCasts(OpBuilder &builder) {
2775  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
2776  ValueRange());
2777 }
2778 
2779 //===----------------------------------------------------------------------===//
2780 // SubViewOp
2781 //===----------------------------------------------------------------------===//
2782 
2783 void SubViewOp::getAsmResultNames(
2784  function_ref<void(Value, StringRef)> setNameFn) {
2785  setNameFn(getResult(), "subview");
2786 }
2787 
2788 /// A subview result type can be fully inferred from the source type and the
2789 /// static representation of offsets, sizes and strides. Special sentinels
2790 /// encode the dynamic case.
2791 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2792  ArrayRef<int64_t> staticOffsets,
2793  ArrayRef<int64_t> staticSizes,
2794  ArrayRef<int64_t> staticStrides) {
2795  unsigned rank = sourceMemRefType.getRank();
2796  (void)rank;
2797  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2798  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2799  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2800 
2801  // Extract source offset and strides.
2802  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2803 
2804  // Compute target offset whose value is:
2805  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2806  int64_t targetOffset = sourceOffset;
2807  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2808  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2809  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2810  SaturatedInteger::wrap(staticOffset) *
2811  SaturatedInteger::wrap(sourceStride))
2812  .asInteger();
2813  }
2814 
2815  // Compute target stride whose value is:
2816  // `sourceStrides_i * staticStrides_i`.
2817  SmallVector<int64_t, 4> targetStrides;
2818  targetStrides.reserve(staticOffsets.size());
2819  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2820  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2821  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2822  SaturatedInteger::wrap(staticStride))
2823  .asInteger());
2824  }
2825 
2826  // The type is now known.
2827  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2828  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2829  targetOffset, targetStrides),
2830  sourceMemRefType.getMemorySpace());
2831 }
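// Editor's note: a worked numeric example of the computation above
// (illustrative values): with sourceStrides = [16, 1], sourceOffset = 0,
// staticOffsets = [2, 3], and staticStrides = [1, 2]:
//   targetOffset  = 0 + 2*16 + 3*1 = 35
//   targetStrides = [16*1, 1*2] = [16, 2]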
2832 
2833 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2834  ArrayRef<OpFoldResult> offsets,
2835  ArrayRef<OpFoldResult> sizes,
2836  ArrayRef<OpFoldResult> strides) {
2837  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2838  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2839  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2840  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2841  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2842  if (!hasValidSizesOffsets(staticOffsets))
2843  return {};
2844  if (!hasValidSizesOffsets(staticSizes))
2845  return {};
2846  if (!hasValidStrides(staticStrides))
2847  return {};
2848  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2849  staticSizes, staticStrides);
2850 }
2851 
2852 MemRefType SubViewOp::inferRankReducedResultType(
2853  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2854  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2855  ArrayRef<int64_t> strides) {
2856  MemRefType inferredType =
2857  inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2858  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2859  "expected ");
2860  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2861  return inferredType;
2862 
2863  // Compute which dimensions are dropped.
2864  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2865  computeRankReductionMask(inferredType.getShape(), resultShape);
2866  assert(dimsToProject.has_value() && "invalid rank reduction");
2867 
2868  // Compute the layout and result type.
2869  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2870  SmallVector<int64_t> rankReducedStrides;
2871  rankReducedStrides.reserve(resultShape.size());
2872  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2873  if (!dimsToProject->contains(idx))
2874  rankReducedStrides.push_back(value);
2875  }
2876  return MemRefType::get(resultShape, inferredType.getElementType(),
2877  StridedLayoutAttr::get(inferredLayout.getContext(),
2878  inferredLayout.getOffset(),
2879  rankReducedStrides),
2880  inferredType.getMemorySpace());
2881 }
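The stride filtering above can be illustrated with a small standalone sketch (plain C++, no MLIR APIs; the shapes and strides are made up, and the greedy unit-dimension matching below is a simplification of what `computeRankReductionMask` actually does): dimensions of the inferred type that do not appear in the requested result shape are dropped, and only the surviving strides are kept in the layout.

```cpp
// Simplified standalone sketch of rank reduction: unit dims of the inferred
// shape that do not appear in the requested result shape are dropped, and the
// corresponding strides are removed from the layout.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> inferredShape = {1, 4, 1, 8};
  std::vector<int64_t> inferredStrides = {128, 32, 8, 1};
  std::vector<int64_t> resultShape = {4, 8};

  std::vector<int64_t> keptStrides;
  size_t j = 0;
  for (size_t i = 0; i < inferredShape.size(); ++i) {
    if (j < resultShape.size() && inferredShape[i] == resultShape[j]) {
      keptStrides.push_back(inferredStrides[i]); // dimension survives
      ++j;
    } else {
      assert(inferredShape[i] == 1 && "only unit dims may be dropped");
    }
  }
  assert(j == resultShape.size());
  assert((keptStrides == std::vector<int64_t>{32, 1}));
  return 0;
}
```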
2882 
2883 MemRefType SubViewOp::inferRankReducedResultType(
2884  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2885  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2886  ArrayRef<OpFoldResult> strides) {
2887  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2888  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2889  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2890  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2891  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2892  return SubViewOp::inferRankReducedResultType(
2893  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2894  staticStrides);
2895 }
2896 
2897 // Build a SubViewOp with mixed static and dynamic entries and custom result
2898 // type. If the type passed is nullptr, it is inferred.
2899 void SubViewOp::build(OpBuilder &b, OperationState &result,
2900  MemRefType resultType, Value source,
2901  ArrayRef<OpFoldResult> offsets,
2902  ArrayRef<OpFoldResult> sizes,
2903  ArrayRef<OpFoldResult> strides,
2904  ArrayRef<NamedAttribute> attrs) {
2905  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2906  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2907  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2908  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2909  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2910  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2911  // Structuring implementation this way avoids duplication between builders.
2912  if (!resultType) {
2913  resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2914  staticSizes, staticStrides);
2915  }
2916  result.addAttributes(attrs);
2917  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2918  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2919  b.getDenseI64ArrayAttr(staticSizes),
2920  b.getDenseI64ArrayAttr(staticStrides));
2921 }
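The builder above relies on `dispatchIndexOpFoldResults` to split each mixed offset/size/stride list into a static attribute and a dynamic operand list. A rough standalone sketch of that convention follows (plain C++, no MLIR APIs; `FoldResult` and the example values are stand-ins for `OpFoldResult`): constants land in the static array, while SSA values are recorded as a `kDynamic` sentinel plus a dynamic operand.

```cpp
// Standalone sketch of the static/dynamic split: constant entries go into the
// static list, SSA-value entries become a kDynamic placeholder plus an entry
// in the dynamic operand list.
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>
#include <variant>
#include <vector>

int main() {
  constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
  // An OpFoldResult is either a constant or an SSA value (modeled as a name).
  using FoldResult = std::variant<int64_t, std::string>;
  std::vector<FoldResult> offsets = {int64_t(0), std::string("%off"), int64_t(2)};

  std::vector<int64_t> staticOffsets;
  std::vector<std::string> dynamicOffsets;
  for (const FoldResult &ofr : offsets) {
    if (auto *cst = std::get_if<int64_t>(&ofr)) {
      staticOffsets.push_back(*cst);
    } else {
      staticOffsets.push_back(kDynamic);                     // placeholder
      dynamicOffsets.push_back(std::get<std::string>(ofr));  // SSA operand
    }
  }
  assert((staticOffsets == std::vector<int64_t>{0, kDynamic, 2}));
  assert(dynamicOffsets.size() == 1);
  return 0;
}
```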
2922 
2923 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2924 // type.
2925 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2926  ArrayRef<OpFoldResult> offsets,
2927  ArrayRef<OpFoldResult> sizes,
2928  ArrayRef<OpFoldResult> strides,
2929  ArrayRef<NamedAttribute> attrs) {
2930  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2931 }
2932 
2933 // Build a SubViewOp with static entries and inferred result type.
2934 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2935  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2936  ArrayRef<int64_t> strides,
2937  ArrayRef<NamedAttribute> attrs) {
2938  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2939  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2940  return b.getI64IntegerAttr(v);
2941  }));
2942  SmallVector<OpFoldResult> sizeValues =
2943  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2944  return b.getI64IntegerAttr(v);
2945  }));
2946  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2947  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2948  return b.getI64IntegerAttr(v);
2949  }));
2950  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2951 }
2952 
2953 // Build a SubViewOp with static entries and custom result type. If the
2954 // type passed is nullptr, it is inferred.
2955 void SubViewOp::build(OpBuilder &b, OperationState &result,
2956  MemRefType resultType, Value source,
2957  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2958  ArrayRef<int64_t> strides,
2959  ArrayRef<NamedAttribute> attrs) {
2960  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2961  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2962  return b.getI64IntegerAttr(v);
2963  }));
2964  SmallVector<OpFoldResult> sizeValues =
2965  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2966  return b.getI64IntegerAttr(v);
2967  }));
2968  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2969  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2970  return b.getI64IntegerAttr(v);
2971  }));
2972  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2973  attrs);
2974 }
2975 
2976 // Build a SubViewOp with dynamic entries and custom result type. If the type
2977 // passed is nullptr, it is inferred.
2978 void SubViewOp::build(OpBuilder &b, OperationState &result,
2979  MemRefType resultType, Value source, ValueRange offsets,
2980  ValueRange sizes, ValueRange strides,
2981  ArrayRef<NamedAttribute> attrs) {
2982  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2983  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2984  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2985  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2986  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2987  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2988  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2989 }
2990 
2991 // Build a SubViewOp with dynamic entries and inferred result type.
2992 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2993  ValueRange offsets, ValueRange sizes, ValueRange strides,
2994  ArrayRef<NamedAttribute> attrs) {
2995  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2996 }
2997 
2998 /// For ViewLikeOpInterface.
2999 Value SubViewOp::getViewSource() { return getSource(); }
3000 
3001 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
3002 /// static value).
3003 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
3004  int64_t t1Offset, t2Offset;
3005  SmallVector<int64_t> t1Strides, t2Strides;
3006  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3007  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3008  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
3009 }
3010 
3011 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
3012 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
3013 /// marked as dropped in `droppedDims`.
3014 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
3015  const llvm::SmallBitVector &droppedDims) {
3016  assert(size_t(t1.getRank()) == droppedDims.size() &&
3017  "incorrect number of bits");
3018  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
3019  "incorrect number of dropped dims");
3020  int64_t t1Offset, t2Offset;
3021  SmallVector<int64_t> t1Strides, t2Strides;
3022  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3023  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3024  if (failed(res1) || failed(res2))
3025  return false;
3026  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
3027  if (droppedDims[i])
3028  continue;
3029  if (t1Strides[i] != t2Strides[j])
3030  return false;
3031  ++j;
3032  }
3033  return true;
3034 }
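A small standalone sketch of the comparison above (plain C++, no MLIR APIs; the strides and the dropped-dimension mask are illustrative): the strides of `t1` are walked with one index while `t2` is walked with a second index that only advances on dimensions that were not dropped.

```cpp
// Standalone sketch of comparing strides while skipping rank-reduced dims.
#include <cassert>
#include <cstdint>
#include <vector>

static bool compatibleStrides(const std::vector<int64_t> &t1Strides,
                              const std::vector<int64_t> &t2Strides,
                              const std::vector<bool> &droppedDims) {
  size_t j = 0;
  for (size_t i = 0; i < t1Strides.size(); ++i) {
    if (droppedDims[i])
      continue; // this dimension does not exist in t2
    if (t1Strides[i] != t2Strides[j])
      return false;
    ++j;
  }
  return j == t2Strides.size();
}

int main() {
  // t1: memref<4x1x8xf32> with strides [8, 8, 1]; dim 1 is rank-reduced away.
  assert(compatibleStrides({8, 8, 1}, {8, 1}, {false, true, false}));
  assert(!compatibleStrides({8, 8, 1}, {8, 2}, {false, true, false}));
  return 0;
}
```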
3035 
3036 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
3037  SubViewOp op, Type expectedType) {
3038  auto memrefType = llvm::cast<ShapedType>(expectedType);
3039  switch (result) {
3040  case SliceVerificationResult::Success:
3041  return success();
3042  case SliceVerificationResult::RankTooLarge:
3043  return op->emitError("expected result rank to be smaller or equal to ")
3044  << "the source rank, but got " << op.getType();
3045  case SliceVerificationResult::SizeMismatch:
3046  return op->emitError("expected result type to be ")
3047  << expectedType
3048  << " or a rank-reduced version. (mismatch of result sizes), but got "
3049  << op.getType();
3050  case SliceVerificationResult::ElemTypeMismatch:
3051  return op->emitError("expected result element type to be ")
3052  << memrefType.getElementType() << ", but got " << op.getType();
3053  case SliceVerificationResult::MemSpaceMismatch:
3054  return op->emitError(
3055  "expected result and source memory spaces to match, but got ")
3056  << op.getType();
3057  case SliceVerificationResult::LayoutMismatch:
3058  return op->emitError("expected result type to be ")
3059  << expectedType
3060  << " or a rank-reduced version. (mismatch of result layout), but "
3061  "got "
3062  << op.getType();
3063  }
3064  llvm_unreachable("unexpected subview verification result");
3065 }
3066 
3067 /// Verifier for SubViewOp.
3068 LogicalResult SubViewOp::verify() {
3069  MemRefType baseType = getSourceType();
3070  MemRefType subViewType = getType();
3071  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
3072  ArrayRef<int64_t> staticSizes = getStaticSizes();
3073  ArrayRef<int64_t> staticStrides = getStaticStrides();
3074 
3075  // The base memref and the view memref should be in the same memory space.
3076  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
3077  return emitError("different memory spaces specified for base memref "
3078  "type ")
3079  << baseType << " and subview memref type " << subViewType;
3080 
3081  // Verify that the base memref type has a strided layout map.
3082  if (!baseType.isStrided())
3083  return emitError("base type ") << baseType << " is not strided";
3084 
3085  // Compute the expected result type, assuming that there are no rank
3086  // reductions.
3087  MemRefType expectedType = SubViewOp::inferResultType(
3088  baseType, staticOffsets, staticSizes, staticStrides);
3089 
3090  // Verify all properties of a shaped type: rank, element type and dimension
3091  // sizes. This takes into account potential rank reductions.
3092  auto shapedTypeVerification = isRankReducedType(
3093  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
3094  if (shapedTypeVerification != SliceVerificationResult::Success)
3095  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
3096 
3097  // Make sure that the memory space did not change.
3098  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3099  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
3100  *this, expectedType);
3101 
3102  // Verify the offset of the layout map.
3103  if (!haveCompatibleOffsets(expectedType, subViewType))
3104  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3105  *this, expectedType);
3106 
3107  // The only things left to verify now are the strides. First, compute
3108  // the unused dimensions due to rank reductions. We have to look at sizes and
3109  // strides to decide which dimensions were dropped. This function also
3110  // partially verifies strides in case of rank reductions.
3111  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3112  getMixedSizes());
3113  if (failed(unusedDims))
3114  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3115  *this, expectedType);
3116 
3117  // Strides must match.
3118  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3119  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3120  *this, expectedType);
3121 
3122  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3123  // to the base memref.
3124  SliceBoundsVerificationResult boundsResult =
3125  verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
3126  staticStrides, /*generateErrorMessage=*/true);
3127  if (!boundsResult.isValid)
3128  return getOperation()->emitError(boundsResult.errorMessage);
3129 
3130  return success();
3131 }
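The final bounds check is delegated to `verifyInBoundsSlice`; the sketch below (plain C++, no MLIR APIs; the helper name and values are made up, and corner cases such as dynamic values are ignored) illustrates the core condition for static values: a non-empty slice must keep its last accessed index, `offset + (size - 1) * stride`, inside the source dimension.

```cpp
// Standalone sketch of the in-bounds condition for fully static slices.
#include <cassert>
#include <cstdint>
#include <vector>

static bool inBounds(const std::vector<int64_t> &shape,
                     const std::vector<int64_t> &offsets,
                     const std::vector<int64_t> &sizes,
                     const std::vector<int64_t> &strides) {
  for (size_t i = 0; i < shape.size(); ++i) {
    if (offsets[i] < 0)
      return false;
    if (sizes[i] == 0)
      continue; // empty slices access nothing
    int64_t lastIndex = offsets[i] + (sizes[i] - 1) * strides[i];
    if (lastIndex >= shape[i])
      return false;
  }
  return true;
}

int main() {
  std::vector<int64_t> shape = {8, 16};
  assert(inBounds(shape, {2, 4}, {4, 4}, {1, 2}));   // last indices: 5, 10
  assert(!inBounds(shape, {2, 4}, {4, 8}, {1, 2}));  // 4 + 7*2 = 18 >= 16
  return 0;
}
```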
3132 
3133 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3134  return os << "range " << range.offset << ":" << range.size << ":"
3135  << range.stride;
3136 }
3137 
3138 /// Return the list of Range (i.e. offset, size, stride). Each Range
3139 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3140 /// with `b` at location `loc`.
3141 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3142  OpBuilder &b, Location loc) {
3143  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3144  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3145  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3146  SmallVector<Range, 8> res;
3147  unsigned rank = ranks[0];
3148  res.reserve(rank);
3149  for (unsigned idx = 0; idx < rank; ++idx) {
3150  Value offset =
3151  op.isDynamicOffset(idx)
3152  ? op.getDynamicOffset(idx)
3153  : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
3154  Value size =
3155  op.isDynamicSize(idx)
3156  ? op.getDynamicSize(idx)
3157  : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
3158  Value stride =
3159  op.isDynamicStride(idx)
3160  ? op.getDynamicStride(idx)
3161  : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
3162  res.emplace_back(Range{offset, size, stride});
3163  }
3164  return res;
3165 }
3166 
3167 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3168 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3169 /// the rank of the inferred result type if `currentResultType` is lower rank
3170 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3171 /// together with the result type. In this case, it is important to compute
3172 /// the dropped dimensions using `currentSourceType` whose strides align with
3173 /// `currentResultType`.
3174 static MemRefType getCanonicalSubViewResultType(
3175  MemRefType currentResultType, MemRefType currentSourceType,
3176  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3177  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3178  MemRefType nonRankReducedType = SubViewOp::inferResultType(
3179  sourceType, mixedOffsets, mixedSizes, mixedStrides);
3180  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3181  currentSourceType, currentResultType, mixedSizes);
3182  if (failed(unusedDims))
3183  return nullptr;
3184 
3185  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3186  SmallVector<int64_t> shape, strides;
3187  unsigned numDimsAfterReduction =
3188  nonRankReducedType.getRank() - unusedDims->count();
3189  shape.reserve(numDimsAfterReduction);
3190  strides.reserve(numDimsAfterReduction);
3191  for (const auto &[idx, size, stride] :
3192  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3193  nonRankReducedType.getShape(), layout.getStrides())) {
3194  if (unusedDims->test(idx))
3195  continue;
3196  shape.push_back(size);
3197  strides.push_back(stride);
3198  }
3199 
3200  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3201  StridedLayoutAttr::get(sourceType.getContext(),
3202  layout.getOffset(), strides),
3203  nonRankReducedType.getMemorySpace());
3204 }
3205 
3206 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3207  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3208  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3209  unsigned rank = memrefType.getRank();
3210  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3211  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3212  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3213  MemRefType targetType = SubViewOp::inferRankReducedResultType(
3214  targetShape, memrefType, offsets, sizes, strides);
3215  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3216  sizes, strides);
3217 }
3218 
3219 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3220  Value value,
3221  ArrayRef<int64_t> desiredShape) {
3222  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3223  assert(sourceMemrefType && "not a ranked memref type");
3224  auto sourceShape = sourceMemrefType.getShape();
3225  if (sourceShape.equals(desiredShape))
3226  return value;
3227  auto maybeRankReductionMask =
3228  mlir::computeRankReductionMask(sourceShape, desiredShape);
3229  if (!maybeRankReductionMask)
3230  return failure();
3231  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3232 }
3233 
3234 /// Helper method to check if a `subview` operation is trivially a no-op. This
3235 /// is the case if all offsets are zero, all strides are 1, and the source
3236 /// shape is the same as the size of the subview. In such cases, the subview can
3237 /// be folded into its source.
3238 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3239  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3240  return false;
3241 
3242  auto mixedOffsets = subViewOp.getMixedOffsets();
3243  auto mixedSizes = subViewOp.getMixedSizes();
3244  auto mixedStrides = subViewOp.getMixedStrides();
3245 
3246  // Check offsets are zero.
3247  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3248  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3249  return !intValue || intValue.value() != 0;
3250  }))
3251  return false;
3252 
3253  // Check strides are one.
3254  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3255  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3256  return !intValue || intValue.value() != 1;
3257  }))
3258  return false;
3259 
3260  // Check that all size values are static and match the (static) source shape.
3261  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3262  for (const auto &size : llvm::enumerate(mixedSizes)) {
3263  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3264  if (!intValue || *intValue != sourceShape[size.index()])
3265  return false;
3266  }
3267  // All conditions met. The `SubViewOp` is foldable as a no-op.
3268  return true;
3269 }
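As a worked example of the three conditions checked above, the following standalone sketch (plain C++, no MLIR APIs; the helper and the shapes are illustrative) accepts a full view of a `memref<8x16xf32>` and rejects one that only covers part of it.

```cpp
// Standalone sketch of the "trivial subview" conditions: offsets all zero,
// strides all one, and sizes equal to the static source shape.
#include <cassert>
#include <cstdint>
#include <vector>

static bool isTrivial(const std::vector<int64_t> &sourceShape,
                      const std::vector<int64_t> &offsets,
                      const std::vector<int64_t> &sizes,
                      const std::vector<int64_t> &strides) {
  for (int64_t o : offsets)
    if (o != 0)
      return false;
  for (int64_t s : strides)
    if (s != 1)
      return false;
  return sizes == sourceShape;
}

int main() {
  assert(isTrivial({8, 16}, {0, 0}, {8, 16}, {1, 1}));   // full view: foldable
  assert(!isTrivial({8, 16}, {0, 0}, {4, 16}, {1, 1}));  // partial view: keep
  return 0;
}
```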
3270 
3271 namespace {
3272 /// Pattern to rewrite a subview op with MemRefCast arguments.
3273 /// This essentially pushes memref.cast past its consuming subview when
3274 /// `canFoldIntoConsumerOp` is true.
3275 ///
3276 /// Example:
3277 /// ```
3278 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3279 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3280 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3281 /// ```
3282 /// is rewritten into:
3283 /// ```
3284 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3285 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3286 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3287 /// ```
3288 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3289 public:
3290  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3291 
3292  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3293  PatternRewriter &rewriter) const override {
3294  // If any operand is a constant, just return and let the
3295  // SubViewOpConstantFolder kick in.
3296  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3297  return matchPattern(operand, matchConstantIndex());
3298  }))
3299  return failure();
3300 
3301  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3302  if (!castOp)
3303  return failure();
3304 
3305  if (!CastOp::canFoldIntoConsumerOp(castOp))
3306  return failure();
3307 
3308  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3309  // the MemRefCastOp source operand type to infer the result type and the
3310  // current SubViewOp source operand type to compute the dropped dimensions
3311  // if the operation is rank-reducing.
3312  auto resultType = getCanonicalSubViewResultType(
3313  subViewOp.getType(), subViewOp.getSourceType(),
3314  llvm::cast<MemRefType>(castOp.getSource().getType()),
3315  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3316  subViewOp.getMixedStrides());
3317  if (!resultType)
3318  return failure();
3319 
3320  Value newSubView = SubViewOp::create(
3321  rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
3322  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3323  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3324  subViewOp.getStaticStrides());
3325  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3326  newSubView);
3327  return success();
3328  }
3329 };
3330 
3331 /// Canonicalize subview ops that are no-ops, replacing them with their source
3332 /// (or a cast of it when the types differ only by layout, e.g. via `affine_map`).
3333 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3334 public:
3335  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3336 
3337  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3338  PatternRewriter &rewriter) const override {
3339  if (!isTrivialSubViewOp(subViewOp))
3340  return failure();
3341  if (subViewOp.getSourceType() == subViewOp.getType()) {
3342  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3343  return success();
3344  }
3345  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3346  subViewOp.getSource());
3347  return success();
3348  }
3349 };
3350 } // namespace
3351 
3352 /// Return the canonical type of the result of a subview.
3353 struct SubViewReturnTypeCanonicalizer {
3354  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3355  ArrayRef<OpFoldResult> mixedSizes,
3356  ArrayRef<OpFoldResult> mixedStrides) {
3357  // Infer a memref type without taking into account any rank reductions.
3358  MemRefType resTy = SubViewOp::inferResultType(
3359  op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3360  if (!resTy)
3361  return {};
3362  MemRefType nonReducedType = resTy;
3363 
3364  // Directly return the non-rank reduced type if there are no dropped dims.
3365  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3366  if (droppedDims.none())
3367  return nonReducedType;
3368 
3369  // Take the strides and offset from the non-rank reduced type.
3370  auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3371 
3372  // Drop dims from shape and strides.
3373  SmallVector<int64_t> targetShape;
3374  SmallVector<int64_t> targetStrides;
3375  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3376  if (droppedDims.test(i))
3377  continue;
3378  targetStrides.push_back(nonReducedStrides[i]);
3379  targetShape.push_back(nonReducedType.getDimSize(i));
3380  }
3381 
3382  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3383  StridedLayoutAttr::get(nonReducedType.getContext(),
3384  offset, targetStrides),
3385  nonReducedType.getMemorySpace());
3386  }
3387 };
3388 
3389 /// A canonicalizer wrapper to replace SubViewOps.
3390 struct SubViewCanonicalizer {
3391  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3392  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3393  }
3394 };
3395 
3396 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3397  MLIRContext *context) {
3398  results
3399  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3400  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3401  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3402 }
3403 
3404 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3405  MemRefType sourceMemrefType = getSource().getType();
3406  MemRefType resultMemrefType = getResult().getType();
3407  auto resultLayout =
3408  dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3409 
3410  if (resultMemrefType == sourceMemrefType &&
3411  resultMemrefType.hasStaticShape() &&
3412  (!resultLayout || resultLayout.hasStaticLayout())) {
3413  return getViewSource();
3414  }
3415 
3416  // Fold subview(subview(x)), where both subviews have the same size and the
3417  // second subview's offsets are all zero. (I.e., the second subview is a
3418  // no-op.)
3419  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3420  auto srcSizes = srcSubview.getMixedSizes();
3421  auto sizes = getMixedSizes();
3422  auto offsets = getMixedOffsets();
3423  bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3424  auto strides = getMixedStrides();
3425  bool allStridesOne = llvm::all_of(strides, isOneInteger);
3426  bool allSizesSame = llvm::equal(sizes, srcSizes);
3427  if (allOffsetsZero && allStridesOne && allSizesSame &&
3428  resultMemrefType == sourceMemrefType)
3429  return getViewSource();
3430  }
3431 
3432  return {};
3433 }
3434 
3435 FailureOr<std::optional<SmallVector<Value>>>
3436 SubViewOp::bubbleDownCasts(OpBuilder &builder) {
3437  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3438 }
3439 
3440 //===----------------------------------------------------------------------===//
3441 // TransposeOp
3442 //===----------------------------------------------------------------------===//
3443 
3444 void TransposeOp::getAsmResultNames(
3445  function_ref<void(Value, StringRef)> setNameFn) {
3446  setNameFn(getResult(), "transpose");
3447 }
3448 
3449 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3450 static MemRefType inferTransposeResultType(MemRefType memRefType,
3451  AffineMap permutationMap) {
3452  auto originalSizes = memRefType.getShape();
3453  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3454  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3455 
3456  // Compute permuted sizes and strides.
3457  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3458  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3459 
3460  return MemRefType::Builder(memRefType)
3461  .setShape(sizes)
3462  .setLayout(
3463  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3464 }
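For intuition, here is a standalone sketch of the permutation applied above (plain C++, no MLIR APIs; values illustrative): for a row-major `memref<4x8xf32>` with strides `[8, 1]` and the permutation `(d0, d1) -> (d1, d0)`, both the sizes and the strides are permuted, which corresponds to `memref<8x4xf32, strided<[1, 8]>>`.

```cpp
// Standalone sketch: permuting sizes and strides of a strided memref type.
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> applyPermutation(const std::vector<int64_t> &vals,
                                             const std::vector<unsigned> &perm) {
  std::vector<int64_t> result(vals.size());
  for (size_t i = 0; i < perm.size(); ++i)
    result[i] = vals[perm[i]];
  return result;
}

int main() {
  std::vector<int64_t> sizes = {4, 8}, strides = {8, 1};
  std::vector<unsigned> perm = {1, 0}; // (d0, d1) -> (d1, d0)
  assert((applyPermutation(sizes, perm) == std::vector<int64_t>{8, 4}));
  assert((applyPermutation(strides, perm) == std::vector<int64_t>{1, 8}));
  return 0;
}
```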
3465 
3466 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3467  AffineMapAttr permutation,
3468  ArrayRef<NamedAttribute> attrs) {
3469  auto permutationMap = permutation.getValue();
3470  assert(permutationMap);
3471 
3472  auto memRefType = llvm::cast<MemRefType>(in.getType());
3473  // Compute result type.
3474  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3475 
3476  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3477  build(b, result, resultType, in, attrs);
3478 }
3479 
3480 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3482  p << " " << getIn() << " " << getPermutation();
3483  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3484  p << " : " << getIn().getType() << " to " << getType();
3485 }
3486 
3487 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3488  OpAsmParser::UnresolvedOperand in;
3489  AffineMap permutation;
3490  MemRefType srcType, dstType;
3491  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3492  parser.parseOptionalAttrDict(result.attributes) ||
3493  parser.parseColonType(srcType) ||
3494  parser.resolveOperand(in, srcType, result.operands) ||
3495  parser.parseKeywordType("to", dstType) ||
3496  parser.addTypeToList(dstType, result.types))
3497  return failure();
3498 
3499  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3500  AffineMapAttr::get(permutation));
3501  return success();
3502 }
3503 
3504 LogicalResult TransposeOp::verify() {
3505  if (!getPermutation().isPermutation())
3506  return emitOpError("expected a permutation map");
3507  if (getPermutation().getNumDims() != getIn().getType().getRank())
3508  return emitOpError("expected a permutation map of same rank as the input");
3509 
3510  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3511  auto resultType = llvm::cast<MemRefType>(getType());
3512  auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3513  .canonicalizeStridedLayout();
3514 
3515  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3516  return emitOpError("result type ")
3517  << resultType
3518  << " is not equivalent to the canonical transposed input type "
3519  << canonicalResultType;
3520  return success();
3521 }
3522 
3523 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3524  // First check for an identity permutation; we can fold it away if the input
3525  // and result types are already identical.
3526  if (getPermutation().isIdentity() && getType() == getIn().getType())
3527  return getIn();
3528  // Fold two consecutive memref.transpose Ops into one by composing their
3529  // permutation maps.
3530  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3531  AffineMap composedPermutation =
3532  getPermutation().compose(otherTransposeOp.getPermutation());
3533  getInMutable().assign(otherTransposeOp.getIn());
3534  setPermutation(composedPermutation);
3535  return getResult();
3536  }
3537  return {};
3538 }
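The second fold above composes the two permutations so that a single transpose has the same effect as applying the inner one and then the outer one. A standalone sketch of that property (plain C++, no MLIR APIs; the `apply`/`compose` helpers and the example permutations are made up for illustration):

```cpp
// Standalone sketch: folding two transposes into one by composing permutations.
#include <cassert>
#include <cstdint>
#include <vector>

// result[i] = v[perm[i]], mirroring how a permutation map reorders sizes.
static std::vector<int64_t> apply(const std::vector<int64_t> &v,
                                  const std::vector<unsigned> &perm) {
  std::vector<int64_t> r(v.size());
  for (size_t i = 0; i < perm.size(); ++i)
    r[i] = v[perm[i]];
  return r;
}

// Composition such that apply(v, composed) == apply(apply(v, inner), outer).
static std::vector<unsigned> compose(const std::vector<unsigned> &outer,
                                     const std::vector<unsigned> &inner) {
  std::vector<unsigned> c(outer.size());
  for (size_t i = 0; i < outer.size(); ++i)
    c[i] = inner[outer[i]];
  return c;
}

int main() {
  std::vector<int64_t> sizes = {2, 3, 5};
  std::vector<unsigned> inner = {2, 0, 1}; // first transpose
  std::vector<unsigned> outer = {1, 2, 0}; // second transpose
  std::vector<unsigned> composed = compose(outer, inner);
  assert(apply(sizes, composed) == apply(apply(sizes, inner), outer));
  return 0;
}
```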
3539 
3540 FailureOr<std::optional<SmallVector<Value>>>
3541 TransposeOp::bubbleDownCasts(OpBuilder &builder) {
3542  return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
3543 }
3544 
3545 //===----------------------------------------------------------------------===//
3546 // ViewOp
3547 //===----------------------------------------------------------------------===//
3548 
3549 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3550  setNameFn(getResult(), "view");
3551 }
3552 
3553 LogicalResult ViewOp::verify() {
3554  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3555  auto viewType = getType();
3556 
3557  // The base memref should have identity layout map (or none).
3558  if (!baseType.getLayout().isIdentity())
3559  return emitError("unsupported map for base memref type ") << baseType;
3560 
3561  // The result memref should have identity layout map (or none).
3562  if (!viewType.getLayout().isIdentity())
3563  return emitError("unsupported map for result memref type ") << viewType;
3564 
3565  // The base memref and the view memref should be in the same memory space.
3566  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3567  return emitError("different memory spaces specified for base memref "
3568  "type ")
3569  << baseType << " and view memref type " << viewType;
3570 
3571  // Verify that we have the correct number of sizes for the result type.
3572  unsigned numDynamicDims = viewType.getNumDynamicDims();
3573  if (getSizes().size() != numDynamicDims)
3574  return emitError("incorrect number of size operands for type ") << viewType;
3575 
3576  return success();
3577 }
3578 
3579 Value ViewOp::getViewSource() { return getSource(); }
3580 
3581 OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3582  MemRefType sourceMemrefType = getSource().getType();
3583  MemRefType resultMemrefType = getResult().getType();
3584 
3585  if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
3586  return getViewSource();
3587 
3588  return {};
3589 }
3590 
3591 namespace {
3592 
3593 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3595 
3596  LogicalResult matchAndRewrite(ViewOp viewOp,
3597  PatternRewriter &rewriter) const override {
3598  // Return if none of the operands are constants.
3599  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3600  return matchPattern(operand, matchConstantIndex());
3601  }))
3602  return failure();
3603 
3604  // Get result memref type.
3605  auto memrefType = viewOp.getType();
3606 
3607  // Get offset from old memref view type 'memRefType'.
3608  int64_t oldOffset;
3609  SmallVector<int64_t, 4> oldStrides;
3610  if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3611  return failure();
3612  assert(oldOffset == 0 && "Expected 0 offset");
3613 
3614  SmallVector<Value, 4> newOperands;
3615 
3616  // Offset cannot be folded into result type.
3617 
3618  // Fold any dynamic dim operands which are produced by a constant.
3619  SmallVector<int64_t, 4> newShapeConstants;
3620  newShapeConstants.reserve(memrefType.getRank());
3621 
3622  unsigned dynamicDimPos = 0;
3623  unsigned rank = memrefType.getRank();
3624  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3625  int64_t dimSize = memrefType.getDimSize(dim);
3626  // If this is already a static dimension, keep it.
3627  if (ShapedType::isStatic(dimSize)) {
3628  newShapeConstants.push_back(dimSize);
3629  continue;
3630  }
3631  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3632  if (auto constantIndexOp =
3633  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3634  // Dynamic shape dimension will be folded.
3635  newShapeConstants.push_back(constantIndexOp.value());
3636  } else {
3637  // Dynamic shape dimension not folded; copy operand from old memref.
3638  newShapeConstants.push_back(dimSize);
3639  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3640  }
3641  dynamicDimPos++;
3642  }
3643 
3644  // Create new memref type with constant folded dims.
3645  MemRefType newMemRefType =
3646  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3647  // Nothing new, don't fold.
3648  if (newMemRefType == memrefType)
3649  return failure();
3650 
3651  // Create new ViewOp.
3652  auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
3653  viewOp.getOperand(0), viewOp.getByteShift(),
3654  newOperands);
3655  // Insert a cast so we have the same type as the old memref type.
3656  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3657  return success();
3658  }
3659 };
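A standalone sketch of the shape folding performed by this pattern (plain C++, no MLIR APIs; `kDynamic` mirrors the dynamic-size sentinel and the example shape is invented): every dynamic dimension whose size operand is a known constant becomes a static dimension of the new type, while the rest stay dynamic and keep their operands.

```cpp
// Standalone sketch: folding constant size operands into a memref shape.
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

int main() {
  constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
  std::vector<int64_t> shape = {kDynamic, 4, kDynamic};
  // Size operands for the dynamic dims, in order; nullopt == not a constant.
  std::vector<std::optional<int64_t>> sizeOperands = {16, std::nullopt};

  std::vector<int64_t> newShape;
  size_t keptDynamicOperands = 0;
  size_t dynPos = 0;
  for (int64_t dim : shape) {
    if (dim != kDynamic) {
      newShape.push_back(dim); // already static: keep as-is
      continue;
    }
    std::optional<int64_t> constant = sizeOperands[dynPos++];
    if (constant) {
      newShape.push_back(*constant); // constant size folded into the type
    } else {
      newShape.push_back(kDynamic);  // stays dynamic; SSA operand is kept
      ++keptDynamicOperands;
    }
  }
  assert((newShape == std::vector<int64_t>{16, 4, kDynamic}));
  assert(keptDynamicOperands == 1);
  return 0;
}
```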
3660 
3661 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3663 
3664  LogicalResult matchAndRewrite(ViewOp viewOp,
3665  PatternRewriter &rewriter) const override {
3666  Value memrefOperand = viewOp.getOperand(0);
3667  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3668  if (!memrefCastOp)
3669  return failure();
3670  Value allocOperand = memrefCastOp.getOperand();
3671  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3672  if (!allocOp)
3673  return failure();
3674  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3675  viewOp.getByteShift(),
3676  viewOp.getSizes());
3677  return success();
3678  }
3679 };
3680 
3681 } // namespace
3682 
3683 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3684  MLIRContext *context) {
3685  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3686 }
3687 
3688 FailureOr<std::optional<SmallVector<Value>>>
3689 ViewOp::bubbleDownCasts(OpBuilder &builder) {
3690  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3691 }
3692 
3693 //===----------------------------------------------------------------------===//
3694 // AtomicRMWOp
3695 //===----------------------------------------------------------------------===//
3696 
3697 LogicalResult AtomicRMWOp::verify() {
3698  if (getMemRefType().getRank() != getNumOperands() - 2)
3699  return emitOpError(
3700  "expects the number of subscripts to be equal to memref rank");
3701  switch (getKind()) {
3702  case arith::AtomicRMWKind::addf:
3703  case arith::AtomicRMWKind::maximumf:
3704  case arith::AtomicRMWKind::minimumf:
3705  case arith::AtomicRMWKind::mulf:
3706  if (!llvm::isa<FloatType>(getValue().getType()))
3707  return emitOpError() << "with kind '"
3708  << arith::stringifyAtomicRMWKind(getKind())
3709  << "' expects a floating-point type";
3710  break;
3711  case arith::AtomicRMWKind::addi:
3712  case arith::AtomicRMWKind::maxs:
3713  case arith::AtomicRMWKind::maxu:
3714  case arith::AtomicRMWKind::mins:
3715  case arith::AtomicRMWKind::minu:
3716  case arith::AtomicRMWKind::muli:
3717  case arith::AtomicRMWKind::ori:
3718  case arith::AtomicRMWKind::xori:
3719  case arith::AtomicRMWKind::andi:
3720  if (!llvm::isa<IntegerType>(getValue().getType()))
3721  return emitOpError() << "with kind '"
3722  << arith::stringifyAtomicRMWKind(getKind())
3723  << "' expects an integer type";
3724  break;
3725  default:
3726  break;
3727  }
3728  return success();
3729 }
3730 
3731 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3732  /// atomicrmw(memrefcast) -> atomicrmw
3733  if (succeeded(foldMemRefCast(*this, getValue())))
3734  return getResult();
3735  return OpFoldResult();
3736 }
3737 
3738 FailureOr<std::optional<SmallVector<Value>>>
3739 AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
3740  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
3741  getResult());
3742 }
3743 
3744 //===----------------------------------------------------------------------===//
3745 // TableGen'd op method definitions
3746 //===----------------------------------------------------------------------===//
3747 
3748 #define GET_OP_CLASSES
3749 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static bool hasSideEffects(Operation *op)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:51
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition: IRAffine.cpp:67
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
Definition: MemRefOps.cpp:96
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
Definition: MemRefOps.cpp:1563
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
Definition: MemRefOps.cpp:117
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
Definition: MemRefOps.cpp:2200
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
Definition: MemRefOps.cpp:436
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
Definition: MemRefOps.cpp:3450
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
Definition: MemRefOps.cpp:417
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
Definition: MemRefOps.cpp:3003
static LogicalResult FoldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
Definition: MemRefOps.cpp:853
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
Definition: MemRefOps.cpp:1400
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
Definition: MemRefOps.cpp:3036
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
Definition: MemRefOps.cpp:3174
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition: MemRefOps.cpp:1577
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
Definition: MemRefOps.cpp:940
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
Definition: MemRefOps.cpp:3238
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
Definition: MemRefOps.cpp:459
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
Definition: MemRefOps.cpp:148
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
Definition: MemRefOps.cpp:3014
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurrences of it.
Definition: MemRefOps.cpp:925
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
Definition: MemRefOps.cpp:2293
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
Definition: MemRefOps.cpp:2493
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
Definition: MemRefOps.cpp:188
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition: TensorOps.cpp:140
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:552
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:104
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition: Block.cpp:250
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:108
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:167
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:100
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:51
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:182
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:203
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:193
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:562
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:525
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Definition: Operation.cpp:226
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Region.h:98
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static WalkResult skip()
Definition: WalkResult.h:48
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition: MemRefOps.cpp:60
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition: MemRefOps.cpp:68
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
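A sketch of the usual call site, using a hypothetical single-result op MyLoadLikeOp: fold hooks of memref-consuming ops delegate to foldMemRefCast so that producing memref.cast ops are absorbed into the operand list.
OpFoldResult MyLoadLikeOp::fold(FoldAdaptor adaptor) {
  // someop(memref.cast(%src)) -> someop(%src)
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}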
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the sizes of all dimensions of the given memref value as a vector of mixed static/dynamic OpFoldResults.
Definition: MemRefOps.cpp:77
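A short sketch, assuming b, loc, and a memref-typed Value src are in scope: static sizes come back as index attributes, dynamic ones as freshly created memref.dim results.
SmallVector<OpFoldResult> sizes = memref::getMixedSizes(b, loc, src);
OpFoldResult innerSize = memref::getMixedSize(b, loc, src, /*dim=*/1);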
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes to reduce the rank of memref to that of targetShape.
Definition: MemRefOps.cpp:3206
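A usage sketch, assuming b, loc, and a Value src of type memref<1x4xf32> (names chosen for illustration): the helper builds a subview with zero offsets, full sizes, and unit strides whose result type drops the unit dimension.
Value reduced = memref::createCanonicalRankReducingSubViewOp(
    b, loc, src, /*targetShape=*/{4});
// reduced is a rank-reduced view of src with shape 4xf32; the exact layout
// attribute is whatever SubViewOp type inference produces.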
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:322
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:527
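A minimal sketch, assuming a Value idx of index type: bind the constant if idx is produced by a constant op.
APInt cst;
if (matchPattern(idx, m_ConstantInt(&cst))) {
  // idx is a compile-time constant; cst.getSExtValue() yields it as int64_t.
  int64_t value = cst.getSExtValue();
  (void)value;
}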
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
Definition: BuiltinTypes.h:356
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
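A short sketch, assuming an OpFoldResult ofr (for example one entry returned by getMixedSizes): the helper works uniformly whether the entry is an Attribute or a Value defined by a constant op.
if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
  // Statically known value; isZeroInteger(ofr) / isOneInteger(ofr) are
  // shorthands for comparisons like the one below.
  bool isUnit = (*cst == 1);
  (void)isUnit;
}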
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size and stride); each entry holds either the dynamic value or a constant index materialized with b at loc.
Definition: MemRefOps.cpp:3141
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:23
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but where all elements for which ShapedType::isDynamic is true are replaced by the corresponding entry of dynamicValues.
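A sketch, assuming op implements OffsetSizeAndStrideOpInterface: rebuild the mixed offset list from its static and dynamic halves (roughly what the interface's getMixedOffsets accessor provides).
SmallVector<OpFoldResult> mixedOffsets =
    getMixedValues(op.getStaticOffsets(), op.getOffsets(), op.getContext());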
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries erased, return the set of indices of originalShape that were dropped to obtain reducedShape.
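A small worked example (shapes chosen for illustration): reducing 1x4x1x8 to 4x8 drops dimensions 0 and 2, so the returned set contains exactly those indices.
std::optional<llvm::SmallDenseSet<unsigned>> mask =
    computeRankReductionMask(/*originalShape=*/{1, 4, 1, 8},
                             /*reducedShape=*/{4, 8});
// mask.has_value() is true and *mask contains {0, 2}.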
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions with static size 1.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:423
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
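A sketch of the inverse of getMixedValues, assuming a SmallVector<OpFoldResult> mixed: static entries land in the int64_t vector (dynamic ones are marked with ShapedType::kDynamic there) and dynamic entries land in the Value vector.
auto [staticVals, dynamicVals] = decomposeMixedValues(mixed);
// dispatchIndexOpFoldResults performs the same split into caller-provided vectors.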
Move allocations into an allocation scope, if it is legal to move them (e.g. their operands are available at the location the op would be moved to).
Definition: MemRefOps.cpp:507
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:510
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocation.
Definition: MemRefOps.cpp:467
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:470
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:2661
A canonicalizer wrapper to replace SubViewOps.
Definition: MemRefOps.cpp:3390
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Definition: MemRefOps.cpp:3391
Return the canonical type of the result of a subview.
Definition: MemRefOps.cpp:3353
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: MemRefOps.cpp:3354
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both expanding dimensions.
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
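A minimal skeleton of such a pattern (the name and the particular rewrite are illustrative only): erase memref.cast ops whose source and result types already agree.
struct DropIdentityCast : public OpRewritePattern<memref::CastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(memref::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (castOp.getSource().getType() != castOp.getType())
      return failure();
    rewriter.replaceOp(castOp, castOp.getSource());
    return success();
  }
};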
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
Result of a slice bounds verification.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.