MLIR  19.0.0git
MemRefOps.cpp
Go to the documentation of this file.
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 /// Materialize a single constant operation from a given attribute value with
30 /// the desired resultant type; construction is delegated to `arith::ConstantOp`.
32  Attribute value, Type type,
33  Location loc) {
34  return arith::ConstantOp::materialize(builder, value, type, loc); // arith dialect builds the constant
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // Common canonicalization pattern support logic
39 //===----------------------------------------------------------------------===//
40 
41 /// This is a common class used for patterns of the form
42 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
43 /// into the root operation directly.
44 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
45  bool folded = false;
46  for (OpOperand &operand : op->getOpOperands()) {
47  auto cast = operand.get().getDefiningOp<CastOp>();
48  if (cast && operand.get() != inner &&
49  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
50  operand.set(cast.getOperand());
51  folded = true;
52  }
53  }
54  return success(folded);
55 }
56 
57 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
58 /// type, keeping only the shape and element type. Non-memref types yield `NoneType`.
60  if (auto memref = llvm::dyn_cast<MemRefType>(type))
61  return RankedTensorType::get(memref.getShape(), memref.getElementType()); // layout/memory space are not carried over
62  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
63  return UnrankedTensorType::get(memref.getElementType());
64  return NoneType::get(type.getContext()); // not a memref type at all
65 }
66 
68  int64_t dim) {
69  auto memrefType = llvm::cast<MemRefType>(value.getType()); // requires a ranked memref; llvm::cast asserts otherwise
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim); // dynamic extent: query it with memref.dim (may fold)
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim)); // static extent: return it as an index attribute
75 }
76 
78  Location loc, Value value) {
79  auto memrefType = llvm::cast<MemRefType>(value.getType()); // requires a ranked memref
81  for (int64_t i = 0; i < memrefType.getRank(); ++i) // one entry per dimension, in order
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result; // mix of index attributes (static dims) and values (dynamic dims)
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that infers the constant values from a list of \p values,
91 /// a \p memRefTy, and another helper function \p getAttributes.
92 /// The inferred constant values replace the related `OpFoldResult` in
93 /// \p values.
94 ///
95 /// \note This function shouldn't be used directly, instead, use the
96 /// `getConstifiedMixedXXX` methods from the related operations.
97 ///
98 /// \p getAttributes returns a list of potentially constant values, as determined
99 /// by \p isDynamic, from the given \p memRefTy. The returned list must have as
100 /// many elements as \p values or be empty.
101 ///
102 /// E.g., consider the following example:
103 /// ```
104 /// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
105 /// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
106 /// ```
107 /// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
108 /// Now using this helper function with:
109 /// - `values == [2, %dyn_stride]`,
110 /// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
111 /// - `getAttributes == getConstantStrides` (i.e., a wrapper around
112 /// `getStridesAndOffset`), and
113 /// - `isDynamic == ShapedType::isDynamic`
114 /// Will yield: `values == [2, 1]`
116  SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
117  MLIRContext *ctxt,
118  llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
119  llvm::function_ref<bool(int64_t)> isDynamic) {
120  SmallVector<int64_t> constValues = getAttributes(memRefTy); // empty when the type carries no static information
121  Builder builder(ctxt);
122  for (const auto &it : llvm::enumerate(constValues)) {
123  int64_t constValue = it.value();
124  if (!isDynamic(constValue))
125  values[it.index()] = builder.getIndexAttr(constValue); // positional overwrite; relies on the size contract documented above
126  }
127  for (OpFoldResult &ofr : values) {
128  if (ofr.is<Attribute>()) {
129  // FIXME: We shouldn't need to do that, but right now, the static indices
130  // are created with the wrong type: `i64` instead of `index`.
131  // As a result, if we were to keep the attribute as is, we may fail to see
132  // that two attributes are equal because one would have the i64 type and
133  // the other the index type.
134  // The alternative would be to create constant indices with getI64Attr in
135  // this and the previous loop, but it doesn't logically make sense (we are
136  // dealing with indices here) and would only strengthen the inconsistency
137  // around how static indices are created (some places use getI64Attr,
138  // others use getIndexAttr).
139  // The workaround here is to stick to the IndexAttr type for all the
140  // values, hence we recreate the attribute even when it is already static
141  // to make sure the type is consistent.
142  ofr = builder.getIndexAttr(
143  llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt()); // NOTE(review): asserts if a static entry is not an IntegerAttr
144  continue;
145  }
146  std::optional<int64_t> maybeConstant =
147  getConstantIntValue(ofr.get<Value>());
148  if (maybeConstant)
149  ofr = builder.getIndexAttr(*maybeConstant); // SSA values defined by constants become index attributes
150  }
151 }
152 
153 /// Wrapper around `getShape` that conforms to the function signature
154 /// expected for `getAttributes` in `constifyIndexValues`.
155 static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
156  ArrayRef<int64_t> sizes = memRefTy.getShape();
157  return SmallVector<int64_t>(sizes.begin(), sizes.end());
158 }
159 
160 /// Wrapper around `getStridesAndOffset` that returns only the offset and
161 /// conforms to the function signature expected for `getAttributes` in
162 /// `constifyIndexValues`.
163 static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
164  SmallVector<int64_t> strides;
165  int64_t offset;
166  LogicalResult hasStaticInformation =
167  getStridesAndOffset(memrefType, strides, offset);
168  if (failed(hasStaticInformation))
169  return SmallVector<int64_t>();
170  return SmallVector<int64_t>(1, offset);
171 }
172 
173 /// Wrapper around `getStridesAndOffset` that returns only the strides and
174 /// conforms to the function signature expected for `getAttributes` in
175 /// `constifyIndexValues`.
176 static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
177  SmallVector<int64_t> strides;
178  int64_t offset;
179  LogicalResult hasStaticInformation =
180  getStridesAndOffset(memrefType, strides, offset);
181  if (failed(hasStaticInformation))
182  return SmallVector<int64_t>();
183  return strides;
184 }
185 
186 //===----------------------------------------------------------------------===//
187 // AllocOp / AllocaOp
188 //===----------------------------------------------------------------------===//
189 
190 void AllocOp::getAsmResultNames(
191  function_ref<void(Value, StringRef)> setNameFn) {
192  setNameFn(getResult(), "alloc");
193 }
194 
195 void AllocaOp::getAsmResultNames(
196  function_ref<void(Value, StringRef)> setNameFn) {
197  setNameFn(getResult(), "alloca");
198 }
199 
200 template <typename AllocLikeOp>
201 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
202  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
203  "applies to only alloc or alloca");
204  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
205  if (!memRefType)
206  return op.emitOpError("result must be a memref");
207 
208  if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
209  memRefType.getNumDynamicDims())
210  return op.emitOpError("dimension operand count does not equal memref "
211  "dynamic dimension count");
212 
213  unsigned numSymbols = 0;
214  if (!memRefType.getLayout().isIdentity())
215  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
216  if (op.getSymbolOperands().size() != numSymbols)
217  return op.emitOpError("symbol operand count does not equal memref symbol "
218  "count: expected ")
219  << numSymbols << ", got " << op.getSymbolOperands().size();
220 
221  return success();
222 }
223 
224 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
225 
226 LogicalResult AllocaOp::verify() {
227  // An alloca op needs to have an ancestor with an allocation scope trait.
228  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
229  return emitOpError(
230  "requires an ancestor op with AutomaticAllocationScope trait");
231 
232  return verifyAllocLikeOp(*this);
233 }
234 
235 namespace {
236 /// Fold non-negative constant dynamic-size operands of an alloc-like op into its static shape; the new alloc is cast back to the old type.
237 template <typename AllocLikeOp>
238 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
240 
241  LogicalResult matchAndRewrite(AllocLikeOp alloc,
242  PatternRewriter &rewriter) const override {
243  // Check to see if any dimensions operands are constants. If so, we can
244  // substitute and drop them.
245  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
246  APInt constSizeArg;
247  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
248  return false;
249  return constSizeArg.isNonNegative(); // negative constants are never folded into the type
250  }))
251  return failure();
252 
253  auto memrefType = alloc.getType();
254 
255  // Ok, we have one or more constant operands. Collect the non-constant ones
256  // and keep track of the resultant memref type to build.
257  SmallVector<int64_t, 4> newShapeConstants;
258  newShapeConstants.reserve(memrefType.getRank());
259  SmallVector<Value, 4> dynamicSizes;
260 
261  unsigned dynamicDimPos = 0; // index into getDynamicSizes(), advanced only on dynamic dims
262  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
263  int64_t dimSize = memrefType.getDimSize(dim);
264  // If this is already static dimension, keep it.
265  if (!ShapedType::isDynamic(dimSize)) {
266  newShapeConstants.push_back(dimSize);
267  continue;
268  }
269  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
270  APInt constSizeArg;
271  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
272  constSizeArg.isNonNegative()) {
273  // Dynamic shape dimension will be folded.
274  newShapeConstants.push_back(constSizeArg.getZExtValue());
275  } else {
276  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
277  newShapeConstants.push_back(ShapedType::kDynamic);
278  dynamicSizes.push_back(dynamicSize);
279  }
280  dynamicDimPos++;
281  }
282 
283  // Create new memref type (which will have fewer dynamic dimensions).
284  MemRefType newMemRefType =
285  MemRefType::Builder(memrefType).setShape(newShapeConstants);
286  assert(static_cast<int64_t>(dynamicSizes.size()) == // remaining operands must match the new type's dynamic dims
287  newMemRefType.getNumDynamicDims());
288 
289  // Create and insert the alloc op for the new memref.
290  auto newAlloc = rewriter.create<AllocLikeOp>(
291  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
292  alloc.getAlignmentAttr());
293  // Insert a cast so we have the same type as the old alloc.
294  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
295  return success();
296  }
297 };
298 
299 /// Erase alloc-like ops whose only users are deallocs and stores *into* the buffer (the users are erased too).
300 template <typename T>
301 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
303 
304  LogicalResult matchAndRewrite(T alloc,
305  PatternRewriter &rewriter) const override {
306  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
307  if (auto storeOp = dyn_cast<StoreOp>(op))
308  return storeOp.getValue() == alloc; // storing the buffer itself elsewhere makes it escape
309  return !isa<DeallocOp>(op);
310  }))
311  return failure();
312 
313  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) // early-inc: users are erased while iterating
314  rewriter.eraseOp(user);
315 
316  rewriter.eraseOp(alloc);
317  return success();
318  }
319 };
320 } // namespace
321 
322 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
323  MLIRContext *context) {
324  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
325 }
326 
327 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
328  MLIRContext *context) {
329  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
330  context);
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // ReallocOp
335 //===----------------------------------------------------------------------===//
336 
337 LogicalResult ReallocOp::verify() {
338  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
339  MemRefType resultType = getType();
340 
341  // The source memref should have identity layout (or none).
342  if (!sourceType.getLayout().isIdentity())
343  return emitError("unsupported layout for source memref type ")
344  << sourceType;
345 
346  // The result memref should have identity layout (or none).
347  if (!resultType.getLayout().isIdentity())
348  return emitError("unsupported layout for result memref type ")
349  << resultType;
350 
351  // The source memref and the result memref should be in the same memory space.
352  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
353  return emitError("different memory spaces specified for source memref "
354  "type ")
355  << sourceType << " and result memref type " << resultType;
356 
357  // The source memref and the result memref should have the same element type.
358  if (sourceType.getElementType() != resultType.getElementType())
359  return emitError("different element types specified for source memref "
360  "type ")
361  << sourceType << " and result memref type " << resultType;
362 
363  // Verify that we have the dynamic dimension operand when it is needed.
364  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
365  return emitError("missing dimension operand for result type ")
366  << resultType;
367  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
368  return emitError("unnecessary dimension operand for result type ")
369  << resultType;
370 
371  return success();
372 }
373 
374 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
375  MLIRContext *context) {
376  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
377 }
378 
379 //===----------------------------------------------------------------------===//
380 // AllocaScopeOp
381 //===----------------------------------------------------------------------===//
382 
384  bool printBlockTerminators = false; // terminators are printed only when the op has results (parse() calls ensureTerminator otherwise)
385 
386  p << ' ';
387  if (!getResults().empty()) {
388  p << " -> (" << getResultTypes() << ")"; // NOTE(review): with the `p << ' '` above this emits two spaces before `->`; likely unintended
389  printBlockTerminators = true;
390  }
391  p << ' ';
392  p.printRegion(getBodyRegion(),
393  /*printEntryBlockArgs=*/false,
394  /*printBlockTerminators=*/printBlockTerminators);
395  p.printOptionalAttrDict((*this)->getAttrs());
396 }
397 
398 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
399  // Create a region for the body.
400  result.regions.reserve(1);
401  Region *bodyRegion = result.addRegion();
402 
403  // Parse optional results type list.
404  if (parser.parseOptionalArrowTypeList(result.types))
405  return failure();
406 
407  // Parse the body region.
408  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
409  return failure();
410  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
411  result.location);
412 
413  // Parse the optional attribute list.
414  if (parser.parseOptionalAttrDict(result.attributes))
415  return failure();
416 
417  return success();
418 }
419 
420 void AllocaScopeOp::getSuccessorRegions(
422  if (!point.isParent()) { // coming from inside the body: control returns to the parent
423  regions.push_back(RegionSuccessor(getResults())); // ...yielding the op's results
424  return;
425  }
426 
427  regions.push_back(RegionSuccessor(&getBodyRegion())); // entering from the parent: control flows into the body
428 }
429 
430 /// Given an operation, return whether this op is guaranteed to
431 /// allocate an AutomaticAllocationScopeResource (an Allocate effect on one of its results).
433  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
434  if (!interface) // no declared effects: an allocation cannot be proven
435  return false;
436  for (auto res : op->getResults()) {
437  if (auto effect =
438  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
439  if (isa<SideEffects::AutomaticAllocationScopeResource>(
440  effect->getResource()))
441  return true;
442  }
443  }
444  return false;
445 }
446 
447 /// Given an operation, return whether this op itself could
448 /// allocate an AutomaticAllocationScopeResource. Note that
449 /// this will not check whether an operation contained within
450 /// the op can allocate.
452  // This op itself doesn't create a stack allocation,
453  // the inner allocation should be handled separately.
455  return false;
456  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
457  if (!interface) // unknown effects: conservatively assume it may allocate
458  return true;
459  for (auto res : op->getResults()) {
460  if (auto effect =
461  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
462  if (isa<SideEffects::AutomaticAllocationScopeResource>(
463  effect->getResource()))
464  return true;
465  }
466  }
467  return false;
468 }
469 
470 /// Return whether this op is the last non-terminator op
471 /// in a region. That is to say, it is in a one-block region
472 /// and is only followed by a terminator. Callers use this check to avoid
473 /// extending the lifetime of allocations.
475  return op->getNextNode() == op->getBlock()->getTerminator() && // NOTE(review): getTerminator() presumably asserts on unterminated blocks
476  op->getParentRegion()->getBlocks().size() == 1;
477 }
478 
479 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
480 /// or it contains no potential allocation; in the former case it must also be its region's last non-terminator op.
481 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
483 
484  LogicalResult matchAndRewrite(AllocaScopeOp op,
485  PatternRewriter &rewriter) const override {
486  bool hasPotentialAlloca =
487  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
488  if (alloc == op)
489  return WalkResult::advance();
491  return WalkResult::interrupt();
492  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
493  return WalkResult::skip(); // nested scopes own their allocations; don't look inside
494  return WalkResult::advance();
495  }).wasInterrupted();
496 
497  // If this contains no potential allocation, it is always legal to
498  // inline. Otherwise, consider two conditions:
499  if (hasPotentialAlloca) {
500  // If the parent isn't an allocation scope, or we are not the last
501  // non-terminator op in the parent, we will extend the lifetime.
503  return failure();
504  if (!lastNonTerminatorInRegion(op))
505  return failure();
506  }
507 
508  Block *block = &op.getRegion().front();
509  Operation *terminator = block->getTerminator();
510  ValueRange results = terminator->getOperands(); // view stays valid until the terminator is erased below
511  rewriter.inlineBlockBefore(block, op); // moves all body ops, terminator included, before `op`
512  rewriter.replaceOp(op, results);
513  rewriter.eraseOp(terminator);
514  return success();
515  }
516 };
517 
518 /// Move allocations into an allocation scope, if it is legal to
519 /// move them (e.g. their operands are available at the location
520 /// the op would be moved to).
521 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
523 
524  LogicalResult matchAndRewrite(AllocaScopeOp op,
525  PatternRewriter &rewriter) const override {
526 
528  return failure();
529 
530  Operation *lastParentWithoutScope = op->getParentOp();
531 
532  if (!lastParentWithoutScope ||
533  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
534  return failure();
535 
536  // Only apply if this is the last non-terminator
537  // op in the block (lest lifetime be extended) of a one
538  // block region
539  if (!lastNonTerminatorInRegion(op) ||
540  !lastNonTerminatorInRegion(lastParentWithoutScope))
541  return failure();
542 
543  while (!lastParentWithoutScope->getParentOp()
545  lastParentWithoutScope = lastParentWithoutScope->getParentOp(); // climb until the parent is an allocation scope
546  if (!lastParentWithoutScope ||
547  !lastNonTerminatorInRegion(lastParentWithoutScope))
548  return failure();
549  }
550  assert(lastParentWithoutScope->getParentOp()
552 
553  Region *containingRegion = nullptr;
554  for (auto &r : lastParentWithoutScope->getRegions()) {
555  if (r.isAncestor(op->getParentRegion())) {
556  assert(containingRegion == nullptr &&
557  "only one region can contain the op");
558  containingRegion = &r;
559  }
560  }
561  assert(containingRegion && "op must be contained in a region");
562 
563  SmallVector<Operation *> toHoist;
564  op->walk([&](Operation *alloc) { // NOTE(review): default walk is post-order, where skip() presumably acts like advance()
566  return WalkResult::skip();
567 
568  // If any operand is not defined before the location of
569  // lastParentWithoutScope (i.e. where we would hoist to), skip.
570  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
571  return containingRegion->isAncestor(v.getParentRegion());
572  }))
573  return WalkResult::skip();
574  toHoist.push_back(alloc);
575  return WalkResult::advance();
576  });
577 
578  if (toHoist.empty())
579  return failure();
580  rewriter.setInsertionPoint(lastParentWithoutScope); // hoisted clones go right before the outermost non-scope parent
581  for (auto *op : toHoist) { // NOTE(review): this `op` shadows the AllocaScopeOp parameter
582  auto *cloned = rewriter.clone(*op);
583  rewriter.replaceOp(op, cloned->getResults());
584  }
585  return success();
586  }
587 };
588 
589 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
590  MLIRContext *context) {
591  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
592 }
593 
594 //===----------------------------------------------------------------------===//
595 // AssumeAlignmentOp
596 //===----------------------------------------------------------------------===//
597 
598 LogicalResult AssumeAlignmentOp::verify() {
599  if (!llvm::isPowerOf2_32(getAlignment()))
600  return emitOpError("alignment must be power of 2");
601  return success();
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // CastOp
606 //===----------------------------------------------------------------------===//
607 
608 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
609  setNameFn(getResult(), "cast");
610 }
611 
612 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
613 /// source memref. This is useful to fold a memref.cast into a consuming op
614 /// and implement canonicalization patterns for ops in different dialects that
615 /// may consume the results of memref.cast operations. Such foldable memref.cast
616 /// operations are typically inserted as `view` and `subview` ops are
617 /// canonicalized, to preserve the type compatibility of their uses.
618 ///
619 /// Returns true when all conditions are met:
620 /// 1. source and result are ranked memrefs with strided semantics and same
621 /// element type and rank.
622 /// 2. each of the source's size, offset or stride has more static information
623 /// than the corresponding result's size, offset or stride.
624 ///
625 /// Example 1:
626 /// ```mlir
627 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
628 /// %2 = consumer %1 ... : memref<?x?xf32> ...
629 /// ```
630 ///
631 /// may fold into:
632 ///
633 /// ```mlir
634 /// %2 = consumer %0 ... : memref<8x16xf32> ...
635 /// ```
636 ///
637 /// Example 2:
638 /// ```
639 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
640 /// to memref<?x?xf32>
641 /// consumer %1 : memref<?x?xf32> ...
642 /// ```
643 ///
644 /// may fold into:
645 ///
646 /// ```
647 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
648 /// ```
649 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
650  MemRefType sourceType =
651  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
652  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
653 
654  // Requires ranked MemRefType.
655  if (!sourceType || !resultType)
656  return false;
657 
658  // Requires same elemental type.
659  if (sourceType.getElementType() != resultType.getElementType())
660  return false;
661 
662  // Requires same rank.
663  if (sourceType.getRank() != resultType.getRank())
664  return false;
665 
666  // Only fold casts between strided memref forms.
667  int64_t sourceOffset, resultOffset;
668  SmallVector<int64_t, 4> sourceStrides, resultStrides;
669  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
670  failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
671  return false;
672 
673  // If cast is towards more static sizes along any dimension, don't fold.
674  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
675  auto ss = std::get<0>(it), st = std::get<1>(it);
676  if (ss != st)
677  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
678  return false;
679  }
680 
681  // If cast is towards more static offset along any dimension, don't fold.
682  if (sourceOffset != resultOffset)
683  if (ShapedType::isDynamic(sourceOffset) &&
684  !ShapedType::isDynamic(resultOffset))
685  return false;
686 
687  // If cast is towards more static strides along any dimension, don't fold.
688  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
689  auto ss = std::get<0>(it), st = std::get<1>(it);
690  if (ss != st)
691  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
692  return false;
693  }
694 
695  return true;
696 }
697 
698 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
699  if (inputs.size() != 1 || outputs.size() != 1)
700  return false;
701  Type a = inputs.front(), b = outputs.front();
702  auto aT = llvm::dyn_cast<MemRefType>(a);
703  auto bT = llvm::dyn_cast<MemRefType>(b);
704 
705  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
706  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
707 
708  if (aT && bT) {
709  if (aT.getElementType() != bT.getElementType())
710  return false;
711  if (aT.getLayout() != bT.getLayout()) {
712  int64_t aOffset, bOffset;
713  SmallVector<int64_t, 4> aStrides, bStrides;
714  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
715  failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
716  aStrides.size() != bStrides.size())
717  return false;
718 
719  // Strides along a dimension/offset are compatible if the value in the
720  // source memref is static and the value in the target memref is the
721  // same. They are also compatible if either one is dynamic (see
722  // description of MemRefCastOp for details).
723  auto checkCompatible = [](int64_t a, int64_t b) {
724  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
725  };
726  if (!checkCompatible(aOffset, bOffset))
727  return false;
728  for (const auto &aStride : enumerate(aStrides))
729  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
730  return false;
731  }
732  if (aT.getMemorySpace() != bT.getMemorySpace())
733  return false;
734 
735  // They must have the same rank, and any specified dimensions must match.
736  if (aT.getRank() != bT.getRank())
737  return false;
738 
739  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
740  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
741  if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
742  aDim != bDim)
743  return false;
744  }
745  return true;
746  } else {
747  if (!aT && !uaT)
748  return false;
749  if (!bT && !ubT)
750  return false;
751  // Unranked to unranked casting is unsupported
752  if (uaT && ubT)
753  return false;
754 
755  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
756  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
757  if (aEltType != bEltType)
758  return false;
759 
760  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
761  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
762  return aMemSpace == bMemSpace;
763  }
764 
765  return false;
766 }
767 
768 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
769  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // CopyOp
774 //===----------------------------------------------------------------------===//
775 
776 namespace {
777 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
778 /// and element type, the cast can be skipped. Such CastOps only cast the layout
779 /// of the type.
780 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
782 
783  LogicalResult matchAndRewrite(CopyOp copyOp,
784  PatternRewriter &rewriter) const override {
785  bool modified = false;
786 
787  // Check source.
788  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
789  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
790  auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
791 
792  if (fromType && toType) {
793  if (fromType.getShape() == toType.getShape() &&
794  fromType.getElementType() == toType.getElementType()) {
795  rewriter.modifyOpInPlace(copyOp, [&] {
796  copyOp.getSourceMutable().assign(castOp.getSource());
797  });
798  modified = true;
799  }
800  }
801  }
802 
803  // Check target.
804  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
805  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
806  auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
807 
808  if (fromType && toType) {
809  if (fromType.getShape() == toType.getShape() &&
810  fromType.getElementType() == toType.getElementType()) {
811  rewriter.modifyOpInPlace(copyOp, [&] {
812  copyOp.getTargetMutable().assign(castOp.getSource());
813  });
814  modified = true;
815  }
816  }
817  }
818 
819  return success(modified);
820  }
821 };
822 
823 /// Fold memref.copy(%x, %x).
824 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
826 
827  LogicalResult matchAndRewrite(CopyOp copyOp,
828  PatternRewriter &rewriter) const override {
829  if (copyOp.getSource() != copyOp.getTarget())
830  return failure();
831 
832  rewriter.eraseOp(copyOp);
833  return success();
834  }
835 };
836 
837 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
839 
840  static bool isEmptyMemRef(BaseMemRefType type) {
841  return type.hasRank() &&
842  llvm::any_of(type.getShape(), [](int64_t x) { return x == 0; });
843  }
844 
845  LogicalResult matchAndRewrite(CopyOp copyOp,
846  PatternRewriter &rewriter) const override {
847  if (isEmptyMemRef(copyOp.getSource().getType()) ||
848  isEmptyMemRef(copyOp.getTarget().getType())) {
849  rewriter.eraseOp(copyOp);
850  return success();
851  }
852 
853  return failure();
854  }
855 };
856 } // namespace
857 
858 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
859  MLIRContext *context) {
860  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
861 }
862 
863 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
865  /// copy(memrefcast) -> copy, only for casts that canFoldIntoConsumerOp accepts
866  bool folded = false;
867  Operation *op = *this;
868  for (OpOperand &operand : op->getOpOperands()) {
869  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
870  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { // cast only erases static info, so its source is usable directly
871  operand.set(castOp.getOperand());
872  folded = true;
873  }
874  }
875  return success(folded); // success iff at least one operand was rewired in place
876 }
877 
878 //===----------------------------------------------------------------------===//
879 // DeallocOp
880 //===----------------------------------------------------------------------===//
881 
/// Fold a memref.cast feeding the dealloc into the dealloc itself, in place,
/// via the shared foldMemRefCast helper. Succeeds iff an operand changed.
882 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
// NOTE(review): original line 883 (the rest of the signature, presumably
// `SmallVectorImpl<OpFoldResult> &results) {`) is missing from this listing.
884  /// dealloc(memrefcast) -> dealloc
885  return foldMemRefCast(*this);
886 }
887 
888 //===----------------------------------------------------------------------===//
889 // DimOp
890 //===----------------------------------------------------------------------===//
891 
892 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
893  setNameFn(getResult(), "dim");
894 }
895 
896 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
897  int64_t index) {
898  auto loc = result.location;
899  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
900  build(builder, result, source, indexValue);
901 }
902 
903 std::optional<int64_t> DimOp::getConstantIndex() {
904  return getConstantIntValue(getIndex());
905 }
906 
/// Report whether memref.dim may be speculated. From the visible checks, each
/// of the following leads to an early exit (the return statements themselves
/// are missing from this listing): a non-constant index, an unranked source,
/// or a constant index that is not below the source rank.
907 Speculation::Speculatability DimOp::getSpeculatability() {
908  auto constantIndex = getConstantIndex();
909  if (!constantIndex)
// NOTE(review): original line 910 (the body of this `if`) is missing from this
// listing; presumably `return Speculation::NotSpeculatable;`.
911
912  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
913  if (!rankedSourceType)
// NOTE(review): original line 914 is missing; presumably
// `return Speculation::NotSpeculatable;` for unranked sources.
915
916  if (rankedSourceType.getRank() <= constantIndex)
// NOTE(review): original line 917 is missing; presumably
// `return Speculation::NotSpeculatable;` for an out-of-range index.
// NOTE(review): a negative constant index passes this check (rank <= negative
// is false); consider also rejecting `*constantIndex < 0` — confirm intent.
918
// NOTE(review): original line 919 is missing; presumably the final
// `return Speculation::Speculatable;`.
920 }
921 
922 /// Return a map with key being elements in `vals` and data being number of
923 /// occurences of it. Use std::map, since the `vals` here are strides and the
924 /// dynamic stride value is the same as the tombstone value for
925 /// `DenseMap<int64_t>`.
926 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
927  std::map<int64_t, unsigned> numOccurences;
928  for (auto val : vals)
929  numOccurences[val]++;
930  return numOccurences;
931 }
932 
933 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
934 /// to be a subset of `originalType` with some `1` entries erased, return the
935 /// set of indices that specifies which of the entries of `originalShape` are
936 /// dropped to obtain `reducedShape`.
937 /// This accounts for cases where there are multiple unit-dims, but only a
938 /// subset of those are dropped. For MemRefTypes these can be disambiguated
939 /// using the strides. If a dimension is dropped the stride must be dropped too.
940 static FailureOr<llvm::SmallBitVector>
941 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
942  ArrayRef<OpFoldResult> sizes) {
943  llvm::SmallBitVector unusedDims(originalType.getRank());
944  if (originalType.getRank() == reducedType.getRank())
945  return unusedDims;
946 
947  for (const auto &dim : llvm::enumerate(sizes))
948  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
949  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
950  unusedDims.set(dim.index());
951 
952  // Early exit for the case where the number of unused dims matches the number
953  // of ranks reduced.
954  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
955  originalType.getRank())
956  return unusedDims;
957 
958  SmallVector<int64_t> originalStrides, candidateStrides;
959  int64_t originalOffset, candidateOffset;
960  if (failed(
961  getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
962  failed(
963  getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
964  return failure();
965 
966  // For memrefs, a dimension is truly dropped if its corresponding stride is
967  // also dropped. This is particularly important when more than one of the dims
968  // is 1. Track the number of occurences of the strides in the original type
969  // and the candidate type. For each unused dim that stride should not be
970  // present in the candidate type. Note that there could be multiple dimensions
971  // that have the same size. We dont need to exactly figure out which dim
972  // corresponds to which stride, we just need to verify that the number of
973  // reptitions of a stride in the original + number of unused dims with that
974  // stride == number of repititions of a stride in the candidate.
975  std::map<int64_t, unsigned> currUnaccountedStrides =
976  getNumOccurences(originalStrides);
977  std::map<int64_t, unsigned> candidateStridesNumOccurences =
978  getNumOccurences(candidateStrides);
979  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
980  if (!unusedDims.test(dim))
981  continue;
982  int64_t originalStride = originalStrides[dim];
983  if (currUnaccountedStrides[originalStride] >
984  candidateStridesNumOccurences[originalStride]) {
985  // This dim can be treated as dropped.
986  currUnaccountedStrides[originalStride]--;
987  continue;
988  }
989  if (currUnaccountedStrides[originalStride] ==
990  candidateStridesNumOccurences[originalStride]) {
991  // The stride for this is not dropped. Keep as is.
992  unusedDims.reset(dim);
993  continue;
994  }
995  if (currUnaccountedStrides[originalStride] <
996  candidateStridesNumOccurences[originalStride]) {
997  // This should never happen. Cant have a stride in the reduced rank type
998  // that wasnt in the original one.
999  return failure();
1000  }
1001  }
1002 
1003  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1004  originalType.getRank())
1005  return failure();
1006  return unusedDims;
1007 }
1008 
1009 llvm::SmallBitVector SubViewOp::getDroppedDims() {
1010  MemRefType sourceType = getSourceType();
1011  MemRefType resultType = getType();
1012  FailureOr<llvm::SmallBitVector> unusedDims =
1013  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1014  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1015  return *unusedDims;
1016 }
1017 
1018 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1019  // All forms of folding require a known index.
1020  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1021  if (!index)
1022  return {};
1023 
1024  // Folding for unranked types (UnrankedMemRefType) is not supported.
1025  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1026  if (!memrefType)
1027  return {};
1028 
1029  // Out of bound indices produce undefined behavior but are still valid IR.
1030  // Don't choke on them.
1031  int64_t indexVal = index.getInt();
1032  if (indexVal < 0 || indexVal >= memrefType.getRank())
1033  return {};
1034 
1035  // Fold if the shape extent along the given index is known.
1036  if (!memrefType.isDynamicDim(index.getInt())) {
1037  Builder builder(getContext());
1038  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1039  }
1040 
1041  // The size at the given index is now known to be a dynamic size.
1042  unsigned unsignedIndex = index.getValue().getZExtValue();
1043 
1044  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1045  Operation *definingOp = getSource().getDefiningOp();
1046 
1047  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1048  return *(alloc.getDynamicSizes().begin() +
1049  memrefType.getDynamicDimIndex(unsignedIndex));
1050 
1051  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1052  return *(alloca.getDynamicSizes().begin() +
1053  memrefType.getDynamicDimIndex(unsignedIndex));
1054 
1055  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1056  return *(view.getDynamicSizes().begin() +
1057  memrefType.getDynamicDimIndex(unsignedIndex));
1058 
1059  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1060  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1061  unsigned resultIndex = 0;
1062  unsigned sourceRank = subview.getSourceType().getRank();
1063  unsigned sourceIndex = 0;
1064  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1065  if (unusedDims.test(i))
1066  continue;
1067  if (resultIndex == unsignedIndex) {
1068  sourceIndex = i;
1069  break;
1070  }
1071  resultIndex++;
1072  }
1073  assert(subview.isDynamicSize(sourceIndex) &&
1074  "expected dynamic subview size");
1075  return subview.getDynamicSize(sourceIndex);
1076  }
1077 
1078  if (auto sizeInterface =
1079  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1080  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1081  "Expected dynamic subview size");
1082  return sizeInterface.getDynamicSize(unsignedIndex);
1083  }
1084 
1085  // dim(memrefcast) -> dim
1086  if (succeeded(foldMemRefCast(*this)))
1087  return getResult();
1088 
1089  return {};
1090 }
1091 
1092 namespace {
1093 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1094 /// operand.
/// The load is created right after the reshape. If the loaded element type
/// differs from the dim result type, it is converted with arith.index_cast.
/// The pattern bails out unless the dim's index value is cheaply known to
/// dominate the reshape (see the checks below).
1095 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
// NOTE(review): original line 1096 is missing from this listing; presumably it
// is `using OpRewritePattern<DimOp>::OpRewritePattern;` — confirm upstream.
1097
1098  LogicalResult matchAndRewrite(DimOp dim,
1099  PatternRewriter &rewriter) const override {
1100  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1101
1102  if (!reshape)
1103  return rewriter.notifyMatchFailure(
1104  dim, "Dim op is not defined by a reshape op.");
1105
1106  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1107  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1108  // cheaply check that either of the following conditions hold:
1109  // 1. dim.getIndex() is defined in the same block as reshape but before
1110  // reshape.
1111  // 2. dim.getIndex() is defined in a parent block of
1112  // reshape.
// Concretely, condition 2 is tested as follows: the index is in a different
// block, and either dim sits in reshape's block, or the index's region is a
// proper ancestor of reshape's region.
1113
1114  // Check condition 1
1115  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1116  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1117  if (reshape->isBeforeInBlock(definingOp)) {
1118  return rewriter.notifyMatchFailure(
1119  dim,
1120  "dim.getIndex is not defined before reshape in the same block.");
1121  }
1122  } // else dim.getIndex is a block argument to reshape->getBlock and
1123  // dominates reshape
1124  } // Check condition 2
1125  else if (dim->getBlock() != reshape->getBlock() &&
1126  !dim.getIndex().getParentRegion()->isProperAncestor(
1127  reshape->getParentRegion())) {
1128  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1129  // already know dim.getIndex() dominates reshape without calling
1130  // `isProperAncestor`
1131  return rewriter.notifyMatchFailure(
1132  dim, "dim.getIndex does not dominate reshape.");
1133  }
1134
1135  // Place the load directly after the reshape to ensure that the shape memref
1136  // was not mutated.
// That is, the load reads the shape values as they were when the reshape ran,
// not after any later writes to the shape buffer.
1137  rewriter.setInsertionPointAfter(reshape);
1138  Location loc = dim.getLoc();
1139  Value load =
1140  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
// The shape buffer's element type may differ from the dim result type.
1141  if (load.getType() != dim.getType())
1142  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1143  rewriter.replaceOp(dim, load);
1144  return success();
1145  }
1146 };
1147
1148 } // namespace
1149 
1150 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1151  MLIRContext *context) {
1152  results.add<DimOfMemRefReshape>(context);
1153 }
1154 
1155 // ---------------------------------------------------------------------------
1156 // DmaStartOp
1157 // ---------------------------------------------------------------------------
1158 
1159 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1160  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1161  ValueRange destIndices, Value numElements,
1162  Value tagMemRef, ValueRange tagIndices, Value stride,
1163  Value elementsPerStride) {
1164  result.addOperands(srcMemRef);
1165  result.addOperands(srcIndices);
1166  result.addOperands(destMemRef);
1167  result.addOperands(destIndices);
1168  result.addOperands({numElements, tagMemRef});
1169  result.addOperands(tagIndices);
1170  if (stride)
1171  result.addOperands({stride, elementsPerStride});
1172 }
1173 
// NOTE(review): the signature line (original line 1174, presumably
// `void DmaStartOp::print(OpAsmPrinter &p) {`) is missing from this listing.
// Printed form:
//   %src[idx...], %dst[idx...], %num, %tag[idx...][, %stride, %per_stride]
//   {attrs} : srcType, dstType, tagType
1175  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1176  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1177  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
// The stride pair is printed only for strided transfers.
1178  if (isStrided())
1179  p << ", " << getStride() << ", " << getNumElementsPerStride();
1180
1181  p.printOptionalAttrDict((*this)->getAttrs());
1182  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1183  << ", " << getTagMemRef().getType();
1184 }
1185 
1186 // Parse DmaStartOp.
1187 // Ex:
1188 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1189 // %tag[%index], %stride, %num_elt_per_stride :
1190 //   memref<3076 x f32, 0>,
1191 // memref<1024 x f32, 2>,
1192 // memref<1 x i32>
1193 //
// Exactly three types (source, destination, tag memrefs) must follow the
// colon. The stride pair is optional, but if present both operands are
// required.
1194 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1195  OpAsmParser::UnresolvedOperand srcMemRefInfo;
// NOTE(review): original line 1196 is missing from this listing; presumably
// the declaration of `srcIndexInfos` (a SmallVector of UnresolvedOperand).
1197  OpAsmParser::UnresolvedOperand dstMemRefInfo;
// NOTE(review): original line 1198 is missing; presumably `dstIndexInfos`.
1199  OpAsmParser::UnresolvedOperand numElementsInfo;
1200  OpAsmParser::UnresolvedOperand tagMemrefInfo;
// NOTE(review): original lines 1201-1202 are missing; presumably the
// declarations of `tagIndexInfos` and `strideInfo`.
1203
1204  SmallVector<Type, 3> types;
1205  auto indexType = parser.getBuilder().getIndexType();
1206
1207  // Parse and resolve the following list of operands:
1208  // *) source memref followed by its indices (in square brackets).
1209  // *) destination memref followed by its indices (in square brackets).
1210  // *) number of elements to transfer (an index value).
//    *) tag memref followed by its indices (in square brackets).
1211  if (parser.parseOperand(srcMemRefInfo) ||
1212  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1213  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1214  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1215  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1216  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1217  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1218  return failure();
1219
1220  // Parse optional stride and elements per stride.
1221  if (parser.parseTrailingOperandList(strideInfo))
1222  return failure();
1223
1224  bool isStrided = strideInfo.size() == 2;
1225  if (!strideInfo.empty() && !isStrided) {
1226  return parser.emitError(parser.getNameLoc(),
1227  "expected two stride related operands");
1228  }
1229
1230  if (parser.parseColonTypeList(types))
1231  return failure();
1232  if (types.size() != 3)
1233  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1234
// Operands are resolved in the same order DmaStartOp::build adds them.
1235  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1236  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1237  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1238  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1239  // size should be an index.
1240  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1241  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1242  // tag indices should be index.
1243  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1244  return failure();
1245
1246  if (isStrided) {
1247  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1248  return failure();
1249  }
1250
1251  return success();
1252 }
1253 
1254 LogicalResult DmaStartOp::verify() {
1255  unsigned numOperands = getNumOperands();
1256 
1257  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1258  // the number of elements.
1259  if (numOperands < 4)
1260  return emitOpError("expected at least 4 operands");
1261 
1262  // Check types of operands. The order of these calls is important: the later
1263  // calls rely on some type properties to compute the operand position.
1264  // 1. Source memref.
1265  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1266  return emitOpError("expected source to be of memref type");
1267  if (numOperands < getSrcMemRefRank() + 4)
1268  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1269  << " operands";
1270  if (!getSrcIndices().empty() &&
1271  !llvm::all_of(getSrcIndices().getTypes(),
1272  [](Type t) { return t.isIndex(); }))
1273  return emitOpError("expected source indices to be of index type");
1274 
1275  // 2. Destination memref.
1276  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1277  return emitOpError("expected destination to be of memref type");
1278  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1279  if (numOperands < numExpectedOperands)
1280  return emitOpError() << "expected at least " << numExpectedOperands
1281  << " operands";
1282  if (!getDstIndices().empty() &&
1283  !llvm::all_of(getDstIndices().getTypes(),
1284  [](Type t) { return t.isIndex(); }))
1285  return emitOpError("expected destination indices to be of index type");
1286 
1287  // 3. Number of elements.
1288  if (!getNumElements().getType().isIndex())
1289  return emitOpError("expected num elements to be of index type");
1290 
1291  // 4. Tag memref.
1292  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1293  return emitOpError("expected tag to be of memref type");
1294  numExpectedOperands += getTagMemRefRank();
1295  if (numOperands < numExpectedOperands)
1296  return emitOpError() << "expected at least " << numExpectedOperands
1297  << " operands";
1298  if (!getTagIndices().empty() &&
1299  !llvm::all_of(getTagIndices().getTypes(),
1300  [](Type t) { return t.isIndex(); }))
1301  return emitOpError("expected tag indices to be of index type");
1302 
1303  // Optional stride-related operands must be either both present or both
1304  // absent.
1305  if (numOperands != numExpectedOperands &&
1306  numOperands != numExpectedOperands + 2)
1307  return emitOpError("incorrect number of operands");
1308 
1309  // 5. Strides.
1310  if (isStrided()) {
1311  if (!getStride().getType().isIndex() ||
1312  !getNumElementsPerStride().getType().isIndex())
1313  return emitOpError(
1314  "expected stride and num elements per stride to be of type index");
1315  }
1316 
1317  return success();
1318 }
1319 
1320 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1321  SmallVectorImpl<OpFoldResult> &results) {
1322  /// dma_start(memrefcast) -> dma_start
1323  return foldMemRefCast(*this);
1324 }
1325 
1326 // ---------------------------------------------------------------------------
1327 // DmaWaitOp
1328 // ---------------------------------------------------------------------------
1329 
1330 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1331  SmallVectorImpl<OpFoldResult> &results) {
1332  /// dma_wait(memrefcast) -> dma_wait
1333  return foldMemRefCast(*this);
1334 }
1335 
1336 LogicalResult DmaWaitOp::verify() {
1337  // Check that the number of tag indices matches the tagMemRef rank.
1338  unsigned numTagIndices = getTagIndices().size();
1339  unsigned tagMemRefRank = getTagMemRefRank();
1340  if (numTagIndices != tagMemRefRank)
1341  return emitOpError() << "expected tagIndices to have the same number of "
1342  "elements as the tagMemRef rank, expected "
1343  << tagMemRefRank << ", but got " << numTagIndices;
1344  return success();
1345 }
1346 
1347 //===----------------------------------------------------------------------===//
1348 // ExtractAlignedPointerAsIndexOp
1349 //===----------------------------------------------------------------------===//
1350 
1351 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1352  function_ref<void(Value, StringRef)> setNameFn) {
1353  setNameFn(getResult(), "intptr");
1354 }
1355 
1356 //===----------------------------------------------------------------------===//
1357 // ExtractStridedMetadataOp
1358 //===----------------------------------------------------------------------===//
1359 
1360 /// The number and type of the results are inferred from the
1361 /// shape of the source.
1362 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1363  MLIRContext *context, std::optional<Location> location,
1364  ExtractStridedMetadataOp::Adaptor adaptor,
1365  SmallVectorImpl<Type> &inferredReturnTypes) {
1366  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1367  if (!sourceType)
1368  return failure();
1369 
1370  unsigned sourceRank = sourceType.getRank();
1371  IndexType indexType = IndexType::get(context);
1372  auto memrefType =
1373  MemRefType::get({}, sourceType.getElementType(),
1374  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1375  // Base.
1376  inferredReturnTypes.push_back(memrefType);
1377  // Offset.
1378  inferredReturnTypes.push_back(indexType);
1379  // Sizes and strides.
1380  for (unsigned i = 0; i < sourceRank * 2; ++i)
1381  inferredReturnTypes.push_back(indexType);
1382  return success();
1383 }
1384 
1385 void ExtractStridedMetadataOp::getAsmResultNames(
1386  function_ref<void(Value, StringRef)> setNameFn) {
1387  setNameFn(getBaseBuffer(), "base_buffer");
1388  setNameFn(getOffset(), "offset");
1389  // For multi-result to work properly with pretty names and packed syntax `x:3`
1390  // we can only give a pretty name to the first value in the pack.
1391  if (!getSizes().empty()) {
1392  setNameFn(getSizes().front(), "sizes");
1393  setNameFn(getStrides().front(), "strides");
1394  }
1395 }
1396 
1397 /// Helper function to perform the replacement of all constant uses of `values`
1398 /// by a materialized constant extracted from `maybeConstants`.
1399 /// `values` and `maybeConstants` are expected to have the same size.
1400 template <typename Container>
1401 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1402  Container values,
1403  ArrayRef<OpFoldResult> maybeConstants) {
1404  assert(values.size() == maybeConstants.size() &&
1405  " expected values and maybeConstants of the same size");
1406  bool atLeastOneReplacement = false;
1407  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1408  // Don't materialize a constant if there are no uses: this would indice
1409  // infinite loops in the driver.
1410  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1411  continue;
1412  assert(maybeConstant.template is<Attribute>() &&
1413  "The constified value should be either unchanged (i.e., == result) "
1414  "or a constant");
1415  Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1416  loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
1417  .getInt());
1418  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1419  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1420  // yet.
1421  op->replaceUsesOfWith(result, constantVal);
1422  atLeastOneReplacement = true;
1423  }
1424  }
1425  return atLeastOneReplacement;
1426 }
1427 
1428 LogicalResult
1429 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1430  SmallVectorImpl<OpFoldResult> &results) {
1431  OpBuilder builder(*this);
1432 
1433  bool atLeastOneReplacement = replaceConstantUsesOf(
1434  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1435  getConstifiedMixedOffset());
1436  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1437  getConstifiedMixedSizes());
1438  atLeastOneReplacement |= replaceConstantUsesOf(
1439  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1440 
1441  return success(atLeastOneReplacement);
1442 }
1443 
1444 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1445  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1446  constifyIndexValues(values, getSource().getType(), getContext(),
1447  getConstantSizes, ShapedType::isDynamic);
1448  return values;
1449 }
1450 
// NOTE(review): original line 1451 (the return type, presumably
// `SmallVector<OpFoldResult>`) is missing from this listing.
// Return the stride results as OpFoldResults. constifyIndexValues (driven by
// getConstantStrides on the source type) may turn entries into constant
// attributes; the other entries remain the original SSA values.
1452 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1453  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1454  constifyIndexValues(values, getSource().getType(), getContext(),
1455  getConstantStrides, ShapedType::isDynamic);
1456  return values;
1457 }
1458 
1459 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1460  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1461  SmallVector<OpFoldResult> values(1, offsetOfr);
1462  constifyIndexValues(values, getSource().getType(), getContext(),
1463  getConstantOffset, ShapedType::isDynamic);
1464  return values[0];
1465 }
1466 
1467 //===----------------------------------------------------------------------===//
1468 // GenericAtomicRMWOp
1469 //===----------------------------------------------------------------------===//
1470 
1471 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1472  Value memref, ValueRange ivs) {
1473  OpBuilder::InsertionGuard g(builder);
1474  result.addOperands(memref);
1475  result.addOperands(ivs);
1476 
1477  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1478  Type elementType = memrefType.getElementType();
1479  result.addTypes(elementType);
1480 
1481  Region *bodyRegion = result.addRegion();
1482  builder.createBlock(bodyRegion);
1483  bodyRegion->addArgument(elementType, memref.getLoc());
1484  }
1485 }
1486 
1487 LogicalResult GenericAtomicRMWOp::verify() {
1488  auto &body = getRegion();
1489  if (body.getNumArguments() != 1)
1490  return emitOpError("expected single number of entry block arguments");
1491 
1492  if (getResult().getType() != body.getArgument(0).getType())
1493  return emitOpError("expected block argument of the same type result type");
1494 
1495  bool hasSideEffects =
1496  body.walk([&](Operation *nestedOp) {
1497  if (isMemoryEffectFree(nestedOp))
1498  return WalkResult::advance();
1499  nestedOp->emitError(
1500  "body of 'memref.generic_atomic_rmw' should contain "
1501  "only operations with no side effects");
1502  return WalkResult::interrupt();
1503  })
1504  .wasInterrupted();
1505  return hasSideEffects ? failure() : success();
1506 }
1507 
// Parse `%memref[%ivs] : memref-type region attr-dict`. The op's single result
// type is the element type of the parsed memref type.
1508 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1509  OperationState &result) {
// NOTE(review): original line 1510 is missing from this listing; presumably
// `OpAsmParser::UnresolvedOperand memref;`.
1511  Type memrefType;
// NOTE(review): original line 1512 is missing; presumably the declaration of
// `ivs` (a SmallVector of UnresolvedOperand).
1513
1514  Type indexType = parser.getBuilder().getIndexType();
1515  if (parser.parseOperand(memref) ||
// NOTE(review): original line 1516 is missing; presumably it parses `ivs` as a
// square-bracketed operand list.
1517  parser.parseColonType(memrefType) ||
1518  parser.resolveOperand(memref, memrefType, result.operands) ||
1519  parser.resolveOperands(ivs, indexType, result.operands))
1520  return failure();
1521
1522  Region *body = result.addRegion();
1523  if (parser.parseRegion(*body, {}) ||
1524  parser.parseOptionalAttrDict(result.attributes))
1525  return failure();
// NOTE(review): this `cast` is not preceded by a check that the parsed type is
// a MemRefType; a non-memref type would assert here instead of producing a
// parse error. Consider a dyn_cast plus parser.emitError.
1526  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1527  return success();
1528 }
1529 
// NOTE(review): the signature line (original line 1530, presumably
// `void GenericAtomicRMWOp::print(OpAsmPrinter &p) {`) is missing from this
// listing.
// Printed form: `%memref[%ivs] : memref-type { region } {attrs}`.
1531  p << ' ' << getMemref() << "[" << getIndices()
1532  << "] : " << getMemref().getType() << ' ';
1533  p.printRegion(getRegion());
1534  p.printOptionalAttrDict((*this)->getAttrs());
1535 }
1536 
1537 //===----------------------------------------------------------------------===//
1538 // AtomicYieldOp
1539 //===----------------------------------------------------------------------===//
1540 
1541 LogicalResult AtomicYieldOp::verify() {
1542  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1543  Type resultType = getResult().getType();
1544  if (parentType != resultType)
1545  return emitOpError() << "types mismatch between yield op: " << resultType
1546  << " and its parent: " << parentType;
1547  return success();
1548 }
1549 
1550 //===----------------------------------------------------------------------===//
1551 // GlobalOp
1552 //===----------------------------------------------------------------------===//
1553 
// NOTE(review): the first signature line (original line 1554, presumably
// `static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p,
// GlobalOp op,`) is missing from this listing.
// Print the global's type and, unless the global is external, ` = ` followed
// by either `uninitialized` or the initializer attribute (without its type).
1555  TypeAttr type,
1556  Attribute initialValue) {
1557  p << type;
1558  if (!op.isExternal()) {
1559  p << " = ";
1560  if (op.isUninitialized())
1561  p << "uninitialized";
1562  else
1563  p.printAttributeWithoutType(initialValue);
1564  }
1565 }
1566 
// Parse the type and optional initializer of a memref.global. Inverse of the
// printer:
//  - the type must be a statically shaped memref;
//  - no `=` means no initial value;
//  - `= uninitialized` yields a UnitAttr;
//  - otherwise an elements attribute of the corresponding tensor type is
//    expected.
1567 static ParseResult
// NOTE(review): original line 1568 (the function name and the first
// parameters, presumably `parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser
// &parser, TypeAttr &typeAttr,`) is missing from this listing.
1569  Attribute &initialValue) {
1570  Type type;
1571  if (parser.parseType(type))
1572  return failure();
1573
1574  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1575  if (!memrefType || !memrefType.hasStaticShape())
1576  return parser.emitError(parser.getNameLoc())
1577  << "type should be static shaped memref, but got " << type;
1578  typeAttr = TypeAttr::get(type);
1579
// No `=`: external global without an initial value.
1580  if (parser.parseOptionalEqual())
1581  return success();
1582
1583  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1584  initialValue = UnitAttr::get(parser.getContext());
1585  return success();
1586  }
1587
// The initializer is parsed against the tensor counterpart of the memref type.
1588  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1589  if (parser.parseAttribute(initialValue, tensorType))
1590  return failure();
1591  if (!llvm::isa<ElementsAttr>(initialValue))
1592  return parser.emitError(parser.getNameLoc())
1593  << "initial value should be a unit or elements attribute";
1594  return success();
1595 }
1596 
1597 LogicalResult GlobalOp::verify() {
1598  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1599  if (!memrefType || !memrefType.hasStaticShape())
1600  return emitOpError("type should be static shaped memref, but got ")
1601  << getType();
1602 
1603  // Verify that the initial value, if present, is either a unit attribute or
1604  // an elements attribute.
1605  if (getInitialValue().has_value()) {
1606  Attribute initValue = getInitialValue().value();
1607  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1608  return emitOpError("initial value should be a unit or elements "
1609  "attribute, but got ")
1610  << initValue;
1611 
1612  // Check that the type of the initial value is compatible with the type of
1613  // the global variable.
1614  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1615  Type initType = elementsAttr.getType();
1616  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1617  if (initType != tensorType)
1618  return emitOpError("initial value expected to be of type ")
1619  << tensorType << ", but was of type " << initType;
1620  }
1621  }
1622 
1623  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1624  uint64_t alignment = *alignAttr;
1625 
1626  if (!llvm::isPowerOf2_64(alignment))
1627  return emitError() << "alignment attribute value " << alignment
1628  << " is not a power of 2";
1629  }
1630 
1631  // TODO: verify visibility for declarations.
1632  return success();
1633 }
1634 
1635 ElementsAttr GlobalOp::getConstantInitValue() {
1636  auto initVal = getInitialValue();
1637  if (getConstant() && initVal.has_value())
1638  return llvm::cast<ElementsAttr>(initVal.value());
1639  return {};
1640 }
1641 
1642 //===----------------------------------------------------------------------===//
1643 // GetGlobalOp
1644 //===----------------------------------------------------------------------===//
1645 
1646 LogicalResult
1647 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1648  // Verify that the result type is same as the type of the referenced
1649  // memref.global op.
1650  auto global =
1651  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1652  if (!global)
1653  return emitOpError("'")
1654  << getName() << "' does not reference a valid global memref";
1655 
1656  Type resultType = getResult().getType();
1657  if (global.getType() != resultType)
1658  return emitOpError("result type ")
1659  << resultType << " does not match type " << global.getType()
1660  << " of the global memref @" << getName();
1661  return success();
1662 }
1663 
1664 //===----------------------------------------------------------------------===//
1665 // LoadOp
1666 //===----------------------------------------------------------------------===//
1667 
1668 LogicalResult LoadOp::verify() {
1669  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1670  return emitOpError("incorrect number of indices for load, expected ")
1671  << getMemRefType().getRank() << " but got " << getIndices().size();
1672  }
1673  return success();
1674 }
1675 
1676 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1677  /// load(memrefcast) -> load
1678  if (succeeded(foldMemRefCast(*this)))
1679  return getResult();
1680  return OpFoldResult();
1681 }
1682 
1683 //===----------------------------------------------------------------------===//
1684 // MemorySpaceCastOp
1685 //===----------------------------------------------------------------------===//
1686 
1687 void MemorySpaceCastOp::getAsmResultNames(
1688  function_ref<void(Value, StringRef)> setNameFn) {
1689  setNameFn(getResult(), "memspacecast");
1690 }
1691 
1692 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1693  if (inputs.size() != 1 || outputs.size() != 1)
1694  return false;
1695  Type a = inputs.front(), b = outputs.front();
1696  auto aT = llvm::dyn_cast<MemRefType>(a);
1697  auto bT = llvm::dyn_cast<MemRefType>(b);
1698 
1699  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1700  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1701 
1702  if (aT && bT) {
1703  if (aT.getElementType() != bT.getElementType())
1704  return false;
1705  if (aT.getLayout() != bT.getLayout())
1706  return false;
1707  if (aT.getShape() != bT.getShape())
1708  return false;
1709  return true;
1710  }
1711  if (uaT && ubT) {
1712  return uaT.getElementType() == ubT.getElementType();
1713  }
1714  return false;
1715 }
1716 
1717 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1718  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1719  // t2)
1720  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1721  getSourceMutable().assign(parentCast.getSource());
1722  return getResult();
1723  }
1724  return Value{};
1725 }
1726 
1727 //===----------------------------------------------------------------------===//
1728 // PrefetchOp
1729 //===----------------------------------------------------------------------===//
1730 
1732  p << " " << getMemref() << '[';
1734  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1735  p << ", locality<" << getLocalityHint();
1736  p << ">, " << (getIsDataCache() ? "data" : "instr");
1738  (*this)->getAttrs(),
1739  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1740  p << " : " << getMemRefType();
1741 }
1742 
1743 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1744  OpAsmParser::UnresolvedOperand memrefInfo;
1746  IntegerAttr localityHint;
1747  MemRefType type;
1748  StringRef readOrWrite, cacheType;
1749 
1750  auto indexTy = parser.getBuilder().getIndexType();
1751  auto i32Type = parser.getBuilder().getIntegerType(32);
1752  if (parser.parseOperand(memrefInfo) ||
1753  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1754  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1755  parser.parseComma() || parser.parseKeyword("locality") ||
1756  parser.parseLess() ||
1757  parser.parseAttribute(localityHint, i32Type, "localityHint",
1758  result.attributes) ||
1759  parser.parseGreater() || parser.parseComma() ||
1760  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1761  parser.resolveOperand(memrefInfo, type, result.operands) ||
1762  parser.resolveOperands(indexInfo, indexTy, result.operands))
1763  return failure();
1764 
1765  if (readOrWrite != "read" && readOrWrite != "write")
1766  return parser.emitError(parser.getNameLoc(),
1767  "rw specifier has to be 'read' or 'write'");
1768  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1769  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1770 
1771  if (cacheType != "data" && cacheType != "instr")
1772  return parser.emitError(parser.getNameLoc(),
1773  "cache type has to be 'data' or 'instr'");
1774 
1775  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1776  parser.getBuilder().getBoolAttr(cacheType == "data"));
1777 
1778  return success();
1779 }
1780 
1781 LogicalResult PrefetchOp::verify() {
1782  if (getNumOperands() != 1 + getMemRefType().getRank())
1783  return emitOpError("too few indices");
1784 
1785  return success();
1786 }
1787 
1788 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1789  SmallVectorImpl<OpFoldResult> &results) {
1790  // prefetch(memrefcast) -> prefetch
1791  return foldMemRefCast(*this);
1792 }
1793 
1794 //===----------------------------------------------------------------------===//
1795 // RankOp
1796 //===----------------------------------------------------------------------===//
1797 
1798 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1799  // Constant fold rank when the rank of the operand is known.
1800  auto type = getOperand().getType();
1801  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1802  if (shapedType && shapedType.hasRank())
1803  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1804  return IntegerAttr();
1805 }
1806 
1807 //===----------------------------------------------------------------------===//
1808 // ReinterpretCastOp
1809 //===----------------------------------------------------------------------===//
1810 
1811 void ReinterpretCastOp::getAsmResultNames(
1812  function_ref<void(Value, StringRef)> setNameFn) {
1813  setNameFn(getResult(), "reinterpret_cast");
1814 }
1815 
1816 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1817 /// `staticSizes` and `staticStrides` are automatically filled with
1818 /// source-memref-rank sentinel values that encode dynamic entries.
1819 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1820  MemRefType resultType, Value source,
1821  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1822  ArrayRef<OpFoldResult> strides,
1823  ArrayRef<NamedAttribute> attrs) {
1824  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1825  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1826  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1827  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1828  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1829  result.addAttributes(attrs);
1830  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1831  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1832  b.getDenseI64ArrayAttr(staticSizes),
1833  b.getDenseI64ArrayAttr(staticStrides));
1834 }
1835 
1836 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1837  MemRefType resultType, Value source,
1838  int64_t offset, ArrayRef<int64_t> sizes,
1839  ArrayRef<int64_t> strides,
1840  ArrayRef<NamedAttribute> attrs) {
1841  SmallVector<OpFoldResult> sizeValues =
1842  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1843  return b.getI64IntegerAttr(v);
1844  }));
1845  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1846  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1847  return b.getI64IntegerAttr(v);
1848  }));
1849  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1850  strideValues, attrs);
1851 }
1852 
1853 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1854  MemRefType resultType, Value source, Value offset,
1855  ValueRange sizes, ValueRange strides,
1856  ArrayRef<NamedAttribute> attrs) {
1857  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1858  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1859  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1860  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1861  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1862 }
1863 
1864 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1865 // completed automatically, like we have for subview and extract_slice.
1866 LogicalResult ReinterpretCastOp::verify() {
1867  // The source and result memrefs should be in the same memory space.
1868  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1869  auto resultType = llvm::cast<MemRefType>(getType());
1870  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1871  return emitError("different memory spaces specified for source type ")
1872  << srcType << " and result memref type " << resultType;
1873  if (srcType.getElementType() != resultType.getElementType())
1874  return emitError("different element types specified for source type ")
1875  << srcType << " and result memref type " << resultType;
1876 
1877  // Match sizes in result memref type and in static_sizes attribute.
1878  for (auto [idx, resultSize, expectedSize] :
1879  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1880  if (!ShapedType::isDynamic(resultSize) &&
1881  !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1882  return emitError("expected result type with size = ")
1883  << expectedSize << " instead of " << resultSize
1884  << " in dim = " << idx;
1885  }
1886 
1887  // Match offset and strides in static_offset and static_strides attributes. If
1888  // result memref type has no affine map specified, this will assume an
1889  // identity layout.
1890  int64_t resultOffset;
1891  SmallVector<int64_t, 4> resultStrides;
1892  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1893  return emitError("expected result type to have strided layout but found ")
1894  << resultType;
1895 
1896  // Match offset in result memref type and in static_offsets attribute.
1897  int64_t expectedOffset = getStaticOffsets().front();
1898  if (!ShapedType::isDynamic(resultOffset) &&
1899  !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
1900  return emitError("expected result type with offset = ")
1901  << expectedOffset << " instead of " << resultOffset;
1902 
1903  // Match strides in result memref type and in static_strides attribute.
1904  for (auto [idx, resultStride, expectedStride] :
1905  llvm::enumerate(resultStrides, getStaticStrides())) {
1906  if (!ShapedType::isDynamic(resultStride) &&
1907  !ShapedType::isDynamic(expectedStride) &&
1908  resultStride != expectedStride)
1909  return emitError("expected result type with stride = ")
1910  << expectedStride << " instead of " << resultStride
1911  << " in dim = " << idx;
1912  }
1913 
1914  return success();
1915 }
1916 
1917 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1918  Value src = getSource();
1919  auto getPrevSrc = [&]() -> Value {
1920  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1921  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1922  return prev.getSource();
1923 
1924  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1925  if (auto prev = src.getDefiningOp<CastOp>())
1926  return prev.getSource();
1927 
1928  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1929  // are 0.
1930  if (auto prev = src.getDefiningOp<SubViewOp>())
1931  if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1932  return isConstantIntValue(val, 0);
1933  }))
1934  return prev.getSource();
1935 
1936  return nullptr;
1937  };
1938 
1939  if (auto prevSrc = getPrevSrc()) {
1940  getSourceMutable().assign(prevSrc);
1941  return getResult();
1942  }
1943 
1944  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1945  if (!ShapedType::isDynamicShape(getType().getShape()) &&
1946  src.getType() == getType() && getStaticOffsets().front() == 0) {
1947  return src;
1948  }
1949 
1950  return nullptr;
1951 }
1952 
1953 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1956  ShapedType::isDynamic);
1957  return values;
1958 }
1959 
1960 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1961  SmallVector<OpFoldResult> values = getMixedStrides();
1963  ShapedType::isDynamic);
1964  return values;
1965 }
1966 
1967 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1968  SmallVector<OpFoldResult> values = getMixedOffsets();
1969  assert(values.size() == 1 &&
1970  "reinterpret_cast must have one and only one offset");
1972  ShapedType::isDynamic);
1973  return values[0];
1974 }
1975 
1976 namespace {
1977 /// Replace the sequence:
1978 /// ```
1979 /// base, offset, sizes, strides = extract_strided_metadata src
1980 /// dst = reinterpret_cast base to offset, sizes, strides
1981 /// ```
1982 /// With
1983 ///
1984 /// ```
1985 /// dst = memref.cast src
1986 /// ```
1987 ///
1988 /// Note: The cast operation is only inserted when the type of dst and src
1989 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
1990 ///
1991 /// This pattern also matches when the offset, sizes, and strides don't come
1992 /// directly from the `extract_strided_metadata`'s results but it can be
1993 /// statically proven that they would hold the same values.
1994 ///
1995 /// For instance, the following sequence would be replaced:
1996 /// ```
1997 /// base, offset, sizes, strides =
1998 /// extract_strided_metadata memref : memref<3x4xty>
1999 /// dst = reinterpret_cast base to 0, [3, 4], strides
2000 /// ```
2001 /// Because we know (thanks to the type of the input memref) that variable
2002 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
2003 ///
2004 /// Similarly, the following sequence would be replaced:
2005 /// ```
2006 /// c0 = arith.constant 0
2007 /// c4 = arith.constant 4
2008 /// base, offset, sizes, strides =
2009 /// extract_strided_metadata memref : memref<3x4xty>
2010 /// dst = reinterpret_cast base to c0, [3, c4], strides
2011 /// ```
2012 /// Because we know that `offset`and `c0` will hold 0
2013 /// and `c4` will hold 4.
2014 struct ReinterpretCastOpExtractStridedMetadataFolder
2015  : public OpRewritePattern<ReinterpretCastOp> {
2016 public:
2018 
2019  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2020  PatternRewriter &rewriter) const override {
2021  auto extractStridedMetadata =
2022  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2023  if (!extractStridedMetadata)
2024  return failure();
2025  // Check if the reinterpret cast reconstructs a memref with the exact same
2026  // properties as the extract strided metadata.
2027 
2028  // First, check that the strides are the same.
2029  SmallVector<OpFoldResult> extractStridesOfr =
2030  extractStridedMetadata.getConstifiedMixedStrides();
2031  SmallVector<OpFoldResult> reinterpretStridesOfr =
2032  op.getConstifiedMixedStrides();
2033  if (extractStridesOfr.size() != reinterpretStridesOfr.size())
2034  return failure();
2035 
2036  unsigned rank = op.getType().getRank();
2037  for (unsigned i = 0; i < rank; ++i) {
2038  if (extractStridesOfr[i] != reinterpretStridesOfr[i])
2039  return failure();
2040  }
2041 
2042  // Second, check the sizes.
2043  assert(extractStridedMetadata.getSizes().size() ==
2044  op.getMixedSizes().size() &&
2045  "Strides and sizes rank must match");
2046  SmallVector<OpFoldResult> extractSizesOfr =
2047  extractStridedMetadata.getConstifiedMixedSizes();
2048  SmallVector<OpFoldResult> reinterpretSizesOfr =
2049  op.getConstifiedMixedSizes();
2050  for (unsigned i = 0; i < rank; ++i) {
2051  if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2052  return failure();
2053  }
2054  // Finally, check the offset.
2055  assert(op.getMixedOffsets().size() == 1 &&
2056  "reinterpret_cast with more than one offset should have been "
2057  "rejected by the verifier");
2058  OpFoldResult extractOffsetOfr =
2059  extractStridedMetadata.getConstifiedMixedOffset();
2060  OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2061  if (extractOffsetOfr != reinterpretOffsetOfr)
2062  return failure();
2063 
2064  // At this point, we know that the back and forth between extract strided
2065  // metadata and reinterpret cast is a noop. However, the final type of the
2066  // reinterpret cast may not be exactly the same as the original memref.
2067  // E.g., it could be changing a dimension from static to dynamic. Check that
2068  // here and add a cast if necessary.
2069  Type srcTy = extractStridedMetadata.getSource().getType();
2070  if (srcTy == op.getResult().getType())
2071  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2072  else
2073  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2074  extractStridedMetadata.getSource());
2075 
2076  return success();
2077  }
2078 };
2079 } // namespace
2080 
2081 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2082  MLIRContext *context) {
2083  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2084 }
2085 
2086 //===----------------------------------------------------------------------===//
2087 // Reassociative reshape ops
2088 //===----------------------------------------------------------------------===//
2089 
2090 void CollapseShapeOp::getAsmResultNames(
2091  function_ref<void(Value, StringRef)> setNameFn) {
2092  setNameFn(getResult(), "collapse_shape");
2093 }
2094 
2095 void ExpandShapeOp::getAsmResultNames(
2096  function_ref<void(Value, StringRef)> setNameFn) {
2097  setNameFn(getResult(), "expand_shape");
2098 }
2099 
2100 LogicalResult ExpandShapeOp::reifyResultShapes(
2101  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2102  reifiedResultShapes = {
2103  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2104  return success();
2105 }
2106 
2107 /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
2108 /// result and operand. Layout maps are verified separately.
2109 ///
2110 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2111 /// allowed in a reassocation group.
2112 static LogicalResult
2114  ArrayRef<int64_t> expandedShape,
2115  ArrayRef<ReassociationIndices> reassociation,
2116  bool allowMultipleDynamicDimsPerGroup) {
2117  // There must be one reassociation group per collapsed dimension.
2118  if (collapsedShape.size() != reassociation.size())
2119  return op->emitOpError("invalid number of reassociation groups: found ")
2120  << reassociation.size() << ", expected " << collapsedShape.size();
2121 
2122  // The next expected expanded dimension index (while iterating over
2123  // reassociation indices).
2124  int64_t nextDim = 0;
2125  for (const auto &it : llvm::enumerate(reassociation)) {
2126  ReassociationIndices group = it.value();
2127  int64_t collapsedDim = it.index();
2128 
2129  bool foundDynamic = false;
2130  for (int64_t expandedDim : group) {
2131  if (expandedDim != nextDim++)
2132  return op->emitOpError("reassociation indices must be contiguous");
2133 
2134  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2135  return op->emitOpError("reassociation index ")
2136  << expandedDim << " is out of bounds";
2137 
2138  // Check if there are multiple dynamic dims in a reassociation group.
2139  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2140  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2141  return op->emitOpError(
2142  "at most one dimension in a reassociation group may be dynamic");
2143  foundDynamic = true;
2144  }
2145  }
2146 
2147  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2148  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2149  return op->emitOpError("collapsed dim (")
2150  << collapsedDim
2151  << ") must be dynamic if and only if reassociation group is "
2152  "dynamic";
2153 
2154  // If all dims in the reassociation group are static, the size of the
2155  // collapsed dim can be verified.
2156  if (!foundDynamic) {
2157  int64_t groupSize = 1;
2158  for (int64_t expandedDim : group)
2159  groupSize *= expandedShape[expandedDim];
2160  if (groupSize != collapsedShape[collapsedDim])
2161  return op->emitOpError("collapsed dim size (")
2162  << collapsedShape[collapsedDim]
2163  << ") must equal reassociation group size (" << groupSize << ")";
2164  }
2165  }
2166 
2167  if (collapsedShape.empty()) {
2168  // Rank 0: All expanded dimensions must be 1.
2169  for (int64_t d : expandedShape)
2170  if (d != 1)
2171  return op->emitOpError(
2172  "rank 0 memrefs can only be extended/collapsed with/from ones");
2173  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2174  // Rank >= 1: Number of dimensions among all reassociation groups must match
2175  // the result memref rank.
2176  return op->emitOpError("expanded rank (")
2177  << expandedShape.size()
2178  << ") inconsistent with number of reassociation indices (" << nextDim
2179  << ")";
2180  }
2181 
2182  return success();
2183 }
2184 
2185 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2186  return getSymbolLessAffineMaps(getReassociationExprs());
2187 }
2188 
2189 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2191  getReassociationIndices());
2192 }
2193 
2194 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2195  return getSymbolLessAffineMaps(getReassociationExprs());
2196 }
2197 
2198 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2200  getReassociationIndices());
2201 }
2202 
2203 /// Compute the layout map after expanding a given source MemRef type with the
2204 /// specified reassociation indices.
2205 static FailureOr<StridedLayoutAttr>
2206 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2207  ArrayRef<ReassociationIndices> reassociation) {
2208  int64_t srcOffset;
2209  SmallVector<int64_t> srcStrides;
2210  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2211  return failure();
2212  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2213 
2214  // 1-1 mapping between srcStrides and reassociation packs.
2215  // Each srcStride starts with the given value and gets expanded according to
2216  // the proper entries in resultShape.
2217  // Example:
2218  // srcStrides = [10000, 1 , 100 ],
2219  // reassociations = [ [0], [1], [2, 3, 4]],
2220  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2221  // -> For the purpose of stride calculation, the useful sizes are:
2222  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2223  // resultStrides = [10000, 1, 600, 200, 100]
2224  // Note that a stride does not get expanded along the first entry of each
2225  // shape pack.
2226  SmallVector<int64_t> reverseResultStrides;
2227  reverseResultStrides.reserve(resultShape.size());
2228  unsigned shapeIndex = resultShape.size() - 1;
2229  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2230  ReassociationIndices reassoc = std::get<0>(it);
2231  int64_t currentStrideToExpand = std::get<1>(it);
2232  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2233  reverseResultStrides.push_back(currentStrideToExpand);
2234  currentStrideToExpand =
2235  (SaturatedInteger::wrap(currentStrideToExpand) *
2236  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2237  .asInteger();
2238  }
2239  }
2240  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2241  resultStrides.resize(resultShape.size(), 1);
2242  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2243 }
2244 
2245 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2246  MemRefType srcType, ArrayRef<int64_t> resultShape,
2247  ArrayRef<ReassociationIndices> reassociation) {
2248  if (srcType.getLayout().isIdentity()) {
2249  // If the source is contiguous (i.e., no layout map specified), so is the
2250  // result.
2251  MemRefLayoutAttrInterface layout;
2252  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2253  srcType.getMemorySpace());
2254  }
2255 
2256  // Source may not be contiguous. Compute the layout map.
2257  FailureOr<StridedLayoutAttr> computedLayout =
2258  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2259  if (failed(computedLayout))
2260  return failure();
2261  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2262  srcType.getMemorySpace());
2263 }
2264 
2265 FailureOr<SmallVector<OpFoldResult>>
2266 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2267  MemRefType expandedType,
2268  ArrayRef<ReassociationIndices> reassociation,
2269  ArrayRef<OpFoldResult> inputShape) {
2270  std::optional<SmallVector<OpFoldResult>> outputShape =
2271  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2272  inputShape);
2273  if (!outputShape)
2274  return failure();
2275  return *outputShape;
2276 }
2277 
2278 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2279  Type resultType, Value src,
2280  ArrayRef<ReassociationIndices> reassociation,
2281  ArrayRef<OpFoldResult> outputShape) {
2282  auto [staticOutputShape, dynamicOutputShape] =
2284  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2285  getReassociationIndicesAttribute(builder, reassociation),
2286  dynamicOutputShape, staticOutputShape);
2287 }
2288 
2289 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2290  Type resultType, Value src,
2291  ArrayRef<ReassociationIndices> reassociation) {
2292  SmallVector<OpFoldResult> inputShape =
2293  getMixedSizes(builder, result.location, src);
2294  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2295  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2296  builder, result.location, memrefResultTy, reassociation, inputShape);
2297  // Failure of this assertion usually indicates presence of multiple
2298  // dynamic dimensions in the same reassociation group.
2299  assert(succeeded(outputShape) && "unable to infer output shape");
2300  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2301 }
2302 
2303 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2304  ArrayRef<int64_t> resultShape, Value src,
2305  ArrayRef<ReassociationIndices> reassociation) {
2306  // Only ranked memref source values are supported.
2307  auto srcType = llvm::cast<MemRefType>(src.getType());
2308  FailureOr<MemRefType> resultType =
2309  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2310  // Failure of this assertion usually indicates a problem with the source
2311  // type, e.g., could not get strides/offset.
2312  assert(succeeded(resultType) && "could not compute layout");
2313  build(builder, result, *resultType, src, reassociation);
2314 }
2315 
2316 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2317  ArrayRef<int64_t> resultShape, Value src,
2318  ArrayRef<ReassociationIndices> reassociation,
2319  ArrayRef<OpFoldResult> outputShape) {
2320  // Only ranked memref source values are supported.
2321  auto srcType = llvm::cast<MemRefType>(src.getType());
2322  FailureOr<MemRefType> resultType =
2323  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2324  // Failure of this assertion usually indicates a problem with the source
2325  // type, e.g., could not get strides/offset.
2326  assert(succeeded(resultType) && "could not compute layout");
2327  build(builder, result, *resultType, src, reassociation, outputShape);
2328 }
2329 
2330 LogicalResult ExpandShapeOp::verify() {
2331  MemRefType srcType = getSrcType();
2332  MemRefType resultType = getResultType();
2333 
2334  if (srcType.getRank() > resultType.getRank()) {
2335  auto r0 = srcType.getRank();
2336  auto r1 = resultType.getRank();
2337  return emitOpError("has source rank ")
2338  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2339  << r0 << " > " << r1 << ").";
2340  }
2341 
2342  // Verify result shape.
2343  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2344  resultType.getShape(),
2345  getReassociationIndices(),
2346  /*allowMultipleDynamicDimsPerGroup=*/true)))
2347  return failure();
2348 
2349  // Compute expected result type (including layout map).
2350  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2351  srcType, resultType.getShape(), getReassociationIndices());
2352  if (failed(expectedResultType))
2353  return emitOpError("invalid source layout map");
2354 
2355  // Check actual result type.
2356  if (*expectedResultType != resultType)
2357  return emitOpError("expected expanded type to be ")
2358  << *expectedResultType << " but found " << resultType;
2359 
2360  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2361  return emitOpError("expected number of static shape bounds to be equal to "
2362  "the output rank (")
2363  << resultType.getRank() << ") but found "
2364  << getStaticOutputShape().size() << " inputs instead";
2365 
2366  if ((int64_t)getOutputShape().size() !=
2367  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2368  return emitOpError("mismatch in dynamic dims in output_shape and "
2369  "static_output_shape: static_output_shape has ")
2370  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2371  << " dynamic dims while output_shape has " << getOutputShape().size()
2372  << " values";
2373 
2374  // Verify if provided output shapes are in agreement with output type.
2375  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2376  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2377  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2378  if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
2379  return emitOpError("invalid output shape provided at pos ") << pos;
2380  }
2381  }
2382 
2383  return success();
2384 }
2385 
2386 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2387  MLIRContext *context) {
2388  results.add<
2391 }
2392 
2393 /// Compute the layout map after collapsing a given source MemRef type with the
2394 /// specified reassociation indices.
2395 ///
2396 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2397 /// not possible to check this by inspecting a MemRefType in the general case.
2398 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2399 /// be valid (and thus accepted by this function) unless `strict = true`.
2400 static FailureOr<StridedLayoutAttr>
2401 computeCollapsedLayoutMap(MemRefType srcType,
2402  ArrayRef<ReassociationIndices> reassociation,
2403  bool strict = false) {
2404  int64_t srcOffset;
2405  SmallVector<int64_t> srcStrides;
2406  auto srcShape = srcType.getShape();
2407  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2408  return failure();
2409 
2410  // The result stride of a reassociation group is the stride of the last entry
2411  // of the reassociation. (TODO: Should be the minimum stride in the
2412  // reassociation because strides are not necessarily sorted. E.g., when using
2413  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2414  // strides are meaningless and could have any arbitrary value.
2415  SmallVector<int64_t> resultStrides;
2416  resultStrides.reserve(reassociation.size());
2417  for (const ReassociationIndices &reassoc : reassociation) {
2418  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2419  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2420  ref = ref.drop_back();
2421  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2422  resultStrides.push_back(srcStrides[ref.back()]);
2423  } else {
2424  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2425  // the corresponding stride may have to be skipped. (See above comment.)
2426  // Therefore, the result stride cannot be statically determined and must
2427  // be dynamic.
2428  resultStrides.push_back(ShapedType::kDynamic);
2429  }
2430  }
2431 
2432  // Validate that each reassociation group is contiguous.
2433  unsigned resultStrideIndex = resultStrides.size() - 1;
2434  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2435  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2436  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2437  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2438  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2439 
2440  // Both source and result stride must have the same static value. In that
2441  // case, we can be sure, that the dimensions are collapsible (because they
2442  // are contiguous).
2443  // If `strict = false` (default during op verification), we accept cases
2444  // where one or both strides are dynamic. This is best effort: We reject
2445  // ops where obviously non-contiguous dims are collapsed, but accept ops
2446  // where we cannot be sure statically. Such ops may fail at runtime. See
2447  // the op documentation for details.
2448  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2449  if (strict && (stride.saturated || srcStride.saturated))
2450  return failure();
2451 
2452  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2453  return failure();
2454  }
2455  }
2456  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2457 }
2458 
2459 bool CollapseShapeOp::isGuaranteedCollapsible(
2460  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2461  // MemRefs with identity layout are always collapsible.
2462  if (srcType.getLayout().isIdentity())
2463  return true;
2464 
2465  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2466  /*strict=*/true));
2467 }
2468 
2469 MemRefType CollapseShapeOp::computeCollapsedType(
2470  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2471  SmallVector<int64_t> resultShape;
2472  resultShape.reserve(reassociation.size());
2473  for (const ReassociationIndices &group : reassociation) {
2474  auto groupSize = SaturatedInteger::wrap(1);
2475  for (int64_t srcDim : group)
2476  groupSize =
2477  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2478  resultShape.push_back(groupSize.asInteger());
2479  }
2480 
2481  if (srcType.getLayout().isIdentity()) {
2482  // If the source is contiguous (i.e., no layout map specified), so is the
2483  // result.
2484  MemRefLayoutAttrInterface layout;
2485  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2486  srcType.getMemorySpace());
2487  }
2488 
2489  // Source may not be fully contiguous. Compute the layout map.
2490  // Note: Dimensions that are collapsed into a single dim are assumed to be
2491  // contiguous.
2492  FailureOr<StridedLayoutAttr> computedLayout =
2493  computeCollapsedLayoutMap(srcType, reassociation);
2494  assert(succeeded(computedLayout) &&
2495  "invalid source layout map or collapsing non-contiguous dims");
2496  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2497  srcType.getMemorySpace());
2498 }
2499 
/// Build a memref.collapse_shape of `src` that collapses dimensions according
/// to `reassociation`.
/// The result type comes from `CollapseShapeOp::computeCollapsedType`. That
/// function asserts the collapse is valid when the source layout is not the
/// identity.
/// The reassociation is encoded with `getReassociationIndicesAttribute`, and
/// `attrs` are passed on to the generated builder.
2500 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2501  ArrayRef<ReassociationIndices> reassociation,
2502  ArrayRef<NamedAttribute> attrs) {
2503  auto srcType = llvm::cast<MemRefType>(src.getType());
2504  MemRefType resultType =
2505  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2507  getReassociationIndicesAttribute(b, reassociation));
2508  build(b, result, resultType, src, attrs);
2509 }
2510 
2511 LogicalResult CollapseShapeOp::verify() {
2512  MemRefType srcType = getSrcType();
2513  MemRefType resultType = getResultType();
2514 
2515  if (srcType.getRank() < resultType.getRank()) {
2516  auto r0 = srcType.getRank();
2517  auto r1 = resultType.getRank();
2518  return emitOpError("has source rank ")
2519  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2520  << r0 << " < " << r1 << ").";
2521  }
2522 
2523  // Verify result shape.
2524  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2525  srcType.getShape(), getReassociationIndices(),
2526  /*allowMultipleDynamicDimsPerGroup=*/true)))
2527  return failure();
2528 
2529  // Compute expected result type (including layout map).
2530  MemRefType expectedResultType;
2531  if (srcType.getLayout().isIdentity()) {
2532  // If the source is contiguous (i.e., no layout map specified), so is the
2533  // result.
2534  MemRefLayoutAttrInterface layout;
2535  expectedResultType =
2536  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2537  srcType.getMemorySpace());
2538  } else {
2539  // Source may not be fully contiguous. Compute the layout map.
2540  // Note: Dimensions that are collapsed into a single dim are assumed to be
2541  // contiguous.
2542  FailureOr<StridedLayoutAttr> computedLayout =
2543  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2544  if (failed(computedLayout))
2545  return emitOpError(
2546  "invalid source layout map or collapsing non-contiguous dims");
2547  expectedResultType =
2548  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2549  *computedLayout, srcType.getMemorySpace());
2550  }
2551 
2552  if (expectedResultType != resultType)
2553  return emitOpError("expected collapsed type to be ")
2554  << expectedResultType << " but found " << resultType;
2555 
2556  return success();
2557 }
2558 
2560  : public OpRewritePattern<CollapseShapeOp> {
2561 public:
2563 
  /// Fold a `memref.cast` that produces the source of a collapse_shape into
  /// the collapse_shape, when `CastOp::canFoldIntoConsumerOp` allows it.
  ///
  /// The collapsed type is recomputed from the cast's source type:
  ///  - if it equals the current result type, the collapse_shape is updated
  ///    in place to read directly from the cast's source;
  ///  - otherwise a new collapse_shape of the cast's source is created and
  ///    cast back to the original result type, so users see the same type.
2564  LogicalResult matchAndRewrite(CollapseShapeOp op,
2565  PatternRewriter &rewriter) const override {
2566  auto cast = op.getOperand().getDefiningOp<CastOp>();
2567  if (!cast)
2568  return failure();
2569 
2570  if (!CastOp::canFoldIntoConsumerOp(cast))
2571  return failure();
2572 
  // NOTE(review): computeCollapsedType asserts when the collapsed layout of
  // the cast's source cannot be computed (e.g. statically non-contiguous
  // strides that were hidden behind dynamic strides in the cast result).
  // Confirm that canFoldIntoConsumerOp rules this case out.
2573  Type newResultType = CollapseShapeOp::computeCollapsedType(
2574  llvm::cast<MemRefType>(cast.getOperand().getType()),
2575  op.getReassociationIndices());
2576 
  // Same type: only the source operand needs to be rewired.
2577  if (newResultType == op.getResultType()) {
2578  rewriter.modifyOpInPlace(
2579  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2580  } else {
    // Different type: collapse the cast's source, then cast the new result
    // back to the type the original users expect.
2581  Value newOp = rewriter.create<CollapseShapeOp>(
2582  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2583  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2584  }
2585  return success();
2586  }
2587 };
2588 
/// Register the canonicalization patterns of memref.collapse_shape with
/// `results`. Among them is `ComposeCollapseOfExpandOp`, instantiated with
/// the memref dialect's collapse/expand/cast/dim ops over `MemRefType`.
2589 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2590  MLIRContext *context) {
2591  results.add<
2593  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2594  memref::DimOp, MemRefType>,
2596 }
2597 
2598 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2599  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2600  adaptor.getOperands());
2601 }
2602 
2603 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2604  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2605  adaptor.getOperands());
2606 }
2607 
2608 //===----------------------------------------------------------------------===//
2609 // ReshapeOp
2610 //===----------------------------------------------------------------------===//
2611 
2612 void ReshapeOp::getAsmResultNames(
2613  function_ref<void(Value, StringRef)> setNameFn) {
2614  setNameFn(getResult(), "reshape");
2615 }
2616 
2617 LogicalResult ReshapeOp::verify() {
2618  Type operandType = getSource().getType();
2619  Type resultType = getResult().getType();
2620 
2621  Type operandElementType =
2622  llvm::cast<ShapedType>(operandType).getElementType();
2623  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2624  if (operandElementType != resultElementType)
2625  return emitOpError("element types of source and destination memref "
2626  "types should be the same");
2627 
2628  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2629  if (!operandMemRefType.getLayout().isIdentity())
2630  return emitOpError("source memref type should have identity affine map");
2631 
2632  int64_t shapeSize =
2633  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2634  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2635  if (resultMemRefType) {
2636  if (!resultMemRefType.getLayout().isIdentity())
2637  return emitOpError("result memref type should have identity affine map");
2638  if (shapeSize == ShapedType::kDynamic)
2639  return emitOpError("cannot use shape operand with dynamic length to "
2640  "reshape to statically-ranked memref type");
2641  if (shapeSize != resultMemRefType.getRank())
2642  return emitOpError(
2643  "length of shape operand differs from the result's memref rank");
2644  }
2645  return success();
2646 }
2647 
2648 //===----------------------------------------------------------------------===//
2649 // StoreOp
2650 //===----------------------------------------------------------------------===//
2651 
2652 LogicalResult StoreOp::verify() {
2653  if (getNumOperands() != 2 + getMemRefType().getRank())
2654  return emitOpError("store index operand count not equal to memref rank");
2655 
2656  return success();
2657 }
2658 
2659 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2660  SmallVectorImpl<OpFoldResult> &results) {
2661  /// store(memrefcast) -> store
2662  return foldMemRefCast(*this, getValueToStore());
2663 }
2664 
2665 //===----------------------------------------------------------------------===//
2666 // SubViewOp
2667 //===----------------------------------------------------------------------===//
2668 
2669 void SubViewOp::getAsmResultNames(
2670  function_ref<void(Value, StringRef)> setNameFn) {
2671  setNameFn(getResult(), "subview");
2672 }
2673 
2674 /// A subview result type can be fully inferred from the source type and the
2675 /// static representation of offsets, sizes and strides. Special sentinels
2676 /// encode the dynamic case.
2677 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2678  ArrayRef<int64_t> staticOffsets,
2679  ArrayRef<int64_t> staticSizes,
2680  ArrayRef<int64_t> staticStrides) {
2681  unsigned rank = sourceMemRefType.getRank();
2682  (void)rank;
2683  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2684  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2685  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2686 
2687  // Extract source offset and strides.
2688  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);
2689 
2690  // Compute target offset whose value is:
2691  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2692  int64_t targetOffset = sourceOffset;
2693  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2694  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2695  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2696  SaturatedInteger::wrap(staticOffset) *
2697  SaturatedInteger::wrap(sourceStride))
2698  .asInteger();
2699  }
2700 
2701  // Compute target stride whose value is:
2702  // `sourceStrides_i * staticStrides_i`.
2703  SmallVector<int64_t, 4> targetStrides;
2704  targetStrides.reserve(staticOffsets.size());
2705  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2706  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2707  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2708  SaturatedInteger::wrap(staticStride))
2709  .asInteger());
2710  }
2711 
2712  // The type is now known.
2713  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2714  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2715  targetOffset, targetStrides),
2716  sourceMemRefType.getMemorySpace());
2717 }
2718 
2719 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2720  ArrayRef<OpFoldResult> offsets,
2721  ArrayRef<OpFoldResult> sizes,
2722  ArrayRef<OpFoldResult> strides) {
2723  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2724  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2725  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2726  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2727  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2728  if (!hasValidSizesOffsets(staticOffsets))
2729  return {};
2730  if (!hasValidSizesOffsets(staticSizes))
2731  return {};
2732  if (!hasValidStrides(staticStrides))
2733  return {};
2734  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2735  staticSizes, staticStrides);
2736 }
2737 
2738 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2739  MemRefType sourceRankedTensorType,
2740  ArrayRef<int64_t> offsets,
2741  ArrayRef<int64_t> sizes,
2742  ArrayRef<int64_t> strides) {
2743  auto inferredType = llvm::cast<MemRefType>(
2744  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2745  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2746  "expected ");
2747  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2748  return inferredType;
2749 
2750  // Compute which dimensions are dropped.
2751  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2752  computeRankReductionMask(inferredType.getShape(), resultShape);
2753  assert(dimsToProject.has_value() && "invalid rank reduction");
2754 
2755  // Compute the layout and result type.
2756  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2757  SmallVector<int64_t> rankReducedStrides;
2758  rankReducedStrides.reserve(resultShape.size());
2759  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2760  if (!dimsToProject->contains(idx))
2761  rankReducedStrides.push_back(value);
2762  }
2763  return MemRefType::get(resultShape, inferredType.getElementType(),
2764  StridedLayoutAttr::get(inferredLayout.getContext(),
2765  inferredLayout.getOffset(),
2766  rankReducedStrides),
2767  inferredType.getMemorySpace());
2768 }
2769 
2770 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2771  MemRefType sourceRankedTensorType,
2772  ArrayRef<OpFoldResult> offsets,
2773  ArrayRef<OpFoldResult> sizes,
2774  ArrayRef<OpFoldResult> strides) {
2775  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2776  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2777  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2778  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2779  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2780  return SubViewOp::inferRankReducedResultType(
2781  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2782  staticStrides);
2783 }
2784 
2785 // Build a SubViewOp with mixed static and dynamic entries and custom result
2786 // type. If the type passed is nullptr, it is inferred.
2787 void SubViewOp::build(OpBuilder &b, OperationState &result,
2788  MemRefType resultType, Value source,
2789  ArrayRef<OpFoldResult> offsets,
2790  ArrayRef<OpFoldResult> sizes,
2791  ArrayRef<OpFoldResult> strides,
2792  ArrayRef<NamedAttribute> attrs) {
2793  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2794  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2795  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2796  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2797  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2798  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2799  // Structuring implementation this way avoids duplication between builders.
2800  if (!resultType) {
2801  resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2802  sourceMemRefType, staticOffsets, staticSizes, staticStrides));
2803  }
2804  result.addAttributes(attrs);
2805  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2806  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2807  b.getDenseI64ArrayAttr(staticSizes),
2808  b.getDenseI64ArrayAttr(staticStrides));
2809 }
2810 
2811 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2812 // type.
2813 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2814  ArrayRef<OpFoldResult> offsets,
2815  ArrayRef<OpFoldResult> sizes,
2816  ArrayRef<OpFoldResult> strides,
2817  ArrayRef<NamedAttribute> attrs) {
2818  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2819 }
2820 
2821 // Build a SubViewOp with static entries and inferred result type.
2822 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2823  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2824  ArrayRef<int64_t> strides,
2825  ArrayRef<NamedAttribute> attrs) {
2826  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2827  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2828  return b.getI64IntegerAttr(v);
2829  }));
2830  SmallVector<OpFoldResult> sizeValues =
2831  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2832  return b.getI64IntegerAttr(v);
2833  }));
2834  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2835  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2836  return b.getI64IntegerAttr(v);
2837  }));
2838  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2839 }
2840 
2841 // Build a SubViewOp with dynamic entries and custom result type. If the
2842 // type passed is nullptr, it is inferred.
2843 void SubViewOp::build(OpBuilder &b, OperationState &result,
2844  MemRefType resultType, Value source,
2845  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2846  ArrayRef<int64_t> strides,
2847  ArrayRef<NamedAttribute> attrs) {
2848  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2849  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2850  return b.getI64IntegerAttr(v);
2851  }));
2852  SmallVector<OpFoldResult> sizeValues =
2853  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2854  return b.getI64IntegerAttr(v);
2855  }));
2856  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2857  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2858  return b.getI64IntegerAttr(v);
2859  }));
2860  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2861  attrs);
2862 }
2863 
2864 // Build a SubViewOp with dynamic entries and custom result type. If the type
2865 // passed is nullptr, it is inferred.
2866 void SubViewOp::build(OpBuilder &b, OperationState &result,
2867  MemRefType resultType, Value source, ValueRange offsets,
2868  ValueRange sizes, ValueRange strides,
2869  ArrayRef<NamedAttribute> attrs) {
2870  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2871  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2872  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2873  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2874  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2875  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2876  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2877 }
2878 
2879 // Build a SubViewOp with dynamic entries and inferred result type.
2880 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2881  ValueRange offsets, ValueRange sizes, ValueRange strides,
2882  ArrayRef<NamedAttribute> attrs) {
2883  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2884 }
2885 
2886 /// For ViewLikeOpInterface.
2887 Value SubViewOp::getViewSource() { return getSource(); }
2888 
2889 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2890 /// static value).
2891 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2892  int64_t t1Offset, t2Offset;
2893  SmallVector<int64_t> t1Strides, t2Strides;
2894  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2895  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2896  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2897 }
2898 
2899 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2900 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2901 /// marked as dropped in `droppedDims`.
2902 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2903  const llvm::SmallBitVector &droppedDims) {
2904  assert(size_t(t1.getRank()) == droppedDims.size() &&
2905  "incorrect number of bits");
2906  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2907  "incorrect number of dropped dims");
2908  int64_t t1Offset, t2Offset;
2909  SmallVector<int64_t> t1Strides, t2Strides;
2910  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2911  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2912  if (failed(res1) || failed(res2))
2913  return false;
2914  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2915  if (droppedDims[i])
2916  continue;
2917  if (t1Strides[i] != t2Strides[j])
2918  return false;
2919  ++j;
2920  }
2921  return true;
2922 }
2923 
2925  Operation *op, Type expectedType) {
  // Turn a slice-verification result into a diagnostic emitted on `op`; the
  // success value maps to success(). `expectedType` is the type the subview
  // result was checked against. SubViewOp::verify passes the
  // non-rank-reduced inferred type. It is quoted in the size and layout
  // mismatch messages.
  // `memrefType` is only used as a ShapedType, to report the element type.
2926  auto memrefType = llvm::cast<ShapedType>(expectedType);
2927  switch (result) {
2929  return success();
2931  return op->emitError("expected result rank to be smaller or equal to ")
2932  << "the source rank. ";
2934  return op->emitError("expected result type to be ")
2935  << expectedType
2936  << " or a rank-reduced version. (mismatch of result sizes) ";
2938  return op->emitError("expected result element type to be ")
2939  << memrefType.getElementType();
2941  return op->emitError("expected result and source memory spaces to match.");
2943  return op->emitError("expected result type to be ")
2944  << expectedType
2945  << " or a rank-reduced version. (mismatch of result layout) ";
2946  }
  // Every result value is handled by the switch above.
2947  llvm_unreachable("unexpected subview verification result");
2948 }
2949 
2950 /// Verifier for SubViewOp. The checks are performed in this order:
///  - source and result use the same memory space;
///  - the source type is strided;
///  - the result is the type inferred from the static offsets/sizes/strides,
///    or a rank-reduced version of it. This is checked piecewise: shape and
///    element type (via isRankReducedType), memory space, layout offset, and
///    finally the strides of the dimensions that are not dropped.
2951 LogicalResult SubViewOp::verify() {
2952  MemRefType baseType = getSourceType();
2953  MemRefType subViewType = getType();
2954 
2955  // The base memref and the view memref should be in the same memory space.
2956  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2957  return emitError("different memory spaces specified for base memref "
2958  "type ")
2959  << baseType << " and subview memref type " << subViewType;
2960 
2961  // Verify that the base memref type has a strided layout map.
2962  if (!isStrided(baseType))
2963  return emitError("base type ") << baseType << " is not strided";
2964 
2965  // Compute the expected result type, assuming that there are no rank
2966  // reductions.
2967  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
2968  baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
2969 
2970  // Verify all properties of a shaped type: rank, element type and dimension
2971  // sizes. This takes into account potential rank reductions.
2972  auto shapedTypeVerification = isRankReducedType(
2973  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2974  if (shapedTypeVerification != SliceVerificationResult::Success)
2975  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2976 
2977  // Make sure that the memory space did not change. NOTE(review): the
  // inferred type copies the source memory space, so this is already implied
  // by the first check above; it is kept as a defensive check.
2978  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
2980  *this, expectedType);
2981 
2982  // Verify the offset of the layout map.
2983  if (!haveCompatibleOffsets(expectedType, subViewType))
2985  *this, expectedType);
2986 
2987  // The only thing that's left to verify now are the strides. First, compute
2988  // the unused dimensions due to rank reductions. We have to look at sizes and
2989  // strides to decide which dimensions were dropped. This function also
2990  // partially verifies strides in case of rank reductions.
2991  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
2992  getMixedSizes());
2993  if (failed(unusedDims))
2995  *this, expectedType);
2996 
2997  // Strides must match (dropped dimensions are skipped).
2998  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3000  *this, expectedType);
3001 
3002  return success();
3003 }
3004 
3005 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3006  return os << "range " << range.offset << ":" << range.size << ":"
3007  << range.stride;
3008 }
3009 
3010 /// Return the list of Range (i.e. offset, size, stride). Each Range
3011 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3012 /// with `b` at location `loc`.
/// A new constant op is created for every static entry on every call;
/// existing constants are not reused.
/// Asserts that the offset, size and stride lists have the same max rank.
3013 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3014  OpBuilder &b, Location loc) {
3015  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3016  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3017  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3019  unsigned rank = ranks[0];
3020  res.reserve(rank);
3021  for (unsigned idx = 0; idx < rank; ++idx) {
  // For each of offset/size/stride: use the dynamic SSA value if present,
  // otherwise materialize the static value as an index constant.
3022  Value offset =
3023  op.isDynamicOffset(idx)
3024  ? op.getDynamicOffset(idx)
3025  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
3026  Value size =
3027  op.isDynamicSize(idx)
3028  ? op.getDynamicSize(idx)
3029  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
3030  Value stride =
3031  op.isDynamicStride(idx)
3032  ? op.getDynamicStride(idx)
3033  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
3034  res.emplace_back(Range{offset, size, stride});
3035  }
3036  return res;
3037 }
3038 
3039 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3040 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3041 /// the rank of the inferred result type if `currentResultType` is lower rank
3042 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3043 /// together with the result type. In this case, it is important to compute
3044 /// the dropped dimensions using `currentSourceType` whose strides align with
3045 /// `currentResultType`.
/// Returns a null type if the rank-reduction mask cannot be computed.
/// NOTE(review): the mixed-value `SubViewOp::inferResultType` overload returns
/// a null type for invalid static offsets/sizes/strides. In that case the
/// `llvm::cast` below would assert, so callers presumably pass only valid
/// values. Confirm this.
3047  MemRefType currentResultType, MemRefType currentSourceType,
3048  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3049  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3050  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
3051  sourceType, mixedOffsets, mixedSizes, mixedStrides));
3052  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3053  currentSourceType, currentResultType, mixedSizes);
3054  if (failed(unusedDims))
3055  return nullptr;
3056 
  // Drop the unused dimensions from the inferred shape and strides; the
  // inferred offset is kept as is.
3057  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3058  SmallVector<int64_t> shape, strides;
3059  unsigned numDimsAfterReduction =
3060  nonRankReducedType.getRank() - unusedDims->count();
3061  shape.reserve(numDimsAfterReduction);
3062  strides.reserve(numDimsAfterReduction);
3063  for (const auto &[idx, size, stride] :
3064  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3065  nonRankReducedType.getShape(), layout.getStrides())) {
3066  if (unusedDims->test(idx))
3067  continue;
3068  shape.push_back(size);
3069  strides.push_back(stride);
3070  }
3071 
3072  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3073  StridedLayoutAttr::get(sourceType.getContext(),
3074  layout.getOffset(), strides),
3075  nonRankReducedType.getMemorySpace());
3076 }
3077 
3079  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
  // Create (or fold) a subview covering all of `memref`: zero offsets, the
  // full (mixed static/dynamic) sizes and unit strides. Its result type is
  // rank-reduced to `targetShape` via
  // `SubViewOp::inferRankReducedResultType`.
3080  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3081  unsigned rank = memrefType.getRank();
3082  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3083  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3084  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3085  auto targetType =
3086  llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
3087  targetShape, memrefType, offsets, sizes, strides));
3088  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3089  sizes, strides);
3090 }
3091 
3092 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3093  Value value,
3094  ArrayRef<int64_t> desiredShape) {
3095  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3096  assert(sourceMemrefType && "not a ranked memref type");
3097  auto sourceShape = sourceMemrefType.getShape();
3098  if (sourceShape.equals(desiredShape))
3099  return value;
3100  auto maybeRankReductionMask =
3101  mlir::computeRankReductionMask(sourceShape, desiredShape);
3102  if (!maybeRankReductionMask)
3103  return failure();
3104  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3105 }
3106 
3107 /// Helper method to check if a `subview` operation is trivially a no-op. This
3108 /// is the case if the all offsets are zero, all strides are 1, and the source
3109 /// shape is same as the size of the subview. In such cases, the subview can
3110 /// be folded into its source.
3111 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3112  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3113  return false;
3114 
3115  auto mixedOffsets = subViewOp.getMixedOffsets();
3116  auto mixedSizes = subViewOp.getMixedSizes();
3117  auto mixedStrides = subViewOp.getMixedStrides();
3118 
3119  // Check offsets are zero.
3120  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3121  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3122  return !intValue || intValue.value() != 0;
3123  }))
3124  return false;
3125 
3126  // Check strides are one.
3127  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3128  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3129  return !intValue || intValue.value() != 1;
3130  }))
3131  return false;
3132 
3133  // Check all size values are static and matches the (static) source shape.
3134  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3135  for (const auto &size : llvm::enumerate(mixedSizes)) {
3136  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3137  if (!intValue || *intValue != sourceShape[size.index()])
3138  return false;
3139  }
3140  // All conditions met. The `SubViewOp` is foldable as a no-op.
3141  return true;
3142 }
3143 
3144 namespace {
3145 /// Pattern to rewrite a subview op whose source is produced by a memref.cast.
3146 /// This essentially pushes memref.cast past its consuming subview when
3147 /// `canFoldIntoConsumerOp` is true.
3148 ///
3149 /// Example:
3150 /// ```
3151 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3152 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3153 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3154 /// ```
3155 /// is rewritten into:
3156 /// ```
3157 /// %0 = memref.subview %V[0, 0][3, 4][1, 1] :
/// memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3158 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3159 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3160 /// ```
3161 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3162 public:
3164 
3165  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3166  PatternRewriter &rewriter) const override {
3167  // If any operand is a constant index, bail out so that the constant
3168  // folding pattern (SubViewOpConstantFolder) can kick in first.
3169  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3170  return matchPattern(operand, matchConstantIndex());
3171  }))
3172  return failure();
3173 
3174  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3175  if (!castOp)
3176  return failure();
3177 
3178  if (!CastOp::canFoldIntoConsumerOp(castOp))
3179  return failure();
3180 
3181  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3182  // the MemRefCastOp source operand type to infer the result type and the
3183  // current SubViewOp source operand type to compute the dropped dimensions
3184  // if the operation is rank-reducing.
3185  auto resultType = getCanonicalSubViewResultType(
3186  subViewOp.getType(), subViewOp.getSourceType(),
3187  llvm::cast<MemRefType>(castOp.getSource().getType()),
3188  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3189  subViewOp.getMixedStrides());
3190  if (!resultType)
3191  return failure();
3192 
    // The new subview reuses the original dynamic operands and static arrays;
    // only its source and result type differ. A cast restores the original
    // result type for existing users.
3193  Value newSubView = rewriter.create<SubViewOp>(
3194  subViewOp.getLoc(), resultType, castOp.getSource(),
3195  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3196  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3197  subViewOp.getStaticStrides());
3198  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3199  newSubView);
3200  return success();
3201  }
3202 };
3203 
3204 /// Canonicalize subview ops that are no-ops. When the source shape is not
3205 /// same as a result shape due to use of `affine_map`.
3206 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3207 public:
3209 
3210  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3211  PatternRewriter &rewriter) const override {
3212  if (!isTrivialSubViewOp(subViewOp))
3213  return failure();
3214  if (subViewOp.getSourceType() == subViewOp.getType()) {
3215  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3216  return success();
3217  }
3218  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3219  subViewOp.getSource());
3220  return success();
3221  }
3222 };
3223 } // namespace
3224 
3225 /// Return the canonical type of the result of a subview.
3227  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3228  ArrayRef<OpFoldResult> mixedSizes,
3229  ArrayRef<OpFoldResult> mixedStrides) {
3230  // Infer a memref type without taking into account any rank reductions.
3231  auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
3232  mixedSizes, mixedStrides);
3233  if (!resTy)
3234  return {};
3235  MemRefType nonReducedType = cast<MemRefType>(resTy);
3236 
3237  // Directly return the non-rank reduced type if there are no dropped dims.
3238  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3239  if (droppedDims.none())
3240  return nonReducedType;
3241 
3242  // Take the strides and offset from the non-rank reduced type.
3243  auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
3244 
3245  // Drop dims from shape and strides.
3246  SmallVector<int64_t> targetShape;
3247  SmallVector<int64_t> targetStrides;
3248  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3249  if (droppedDims.test(i))
3250  continue;
3251  targetStrides.push_back(nonReducedStrides[i]);
3252  targetShape.push_back(nonReducedType.getDimSize(i));
3253  }
3254 
3255  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3256  StridedLayoutAttr::get(nonReducedType.getContext(),
3257  offset, targetStrides),
3258  nonReducedType.getMemorySpace());
3259  }
3260 };
3261 
3262 /// A canonicalizer wrapper to replace SubViewOps.
3264  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3265  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3266  }
3267 };
3268 
3269 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3270  MLIRContext *context) {
3271  results
3274  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3275 }
3276 
3277 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3278  auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
3279  auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());
3280 
3281  if (resultShapedType.hasStaticShape() &&
3282  resultShapedType == sourceShapedType) {
3283  return getViewSource();
3284  }
3285 
3286  // Fold subview(subview(x)), where both subviews have the same size and the
3287  // second subview's offsets are all zero. (I.e., the second subview is a
3288  // no-op.)
3289  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3290  auto srcSizes = srcSubview.getMixedSizes();
3291  auto sizes = getMixedSizes();
3292  auto offsets = getMixedOffsets();
3293  bool allOffsetsZero = llvm::all_of(
3294  offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3295  auto strides = getMixedStrides();
3296  bool allStridesOne = llvm::all_of(
3297  strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3298  bool allSizesSame = llvm::equal(sizes, srcSizes);
3299  if (allOffsetsZero && allStridesOne && allSizesSame &&
3300  resultShapedType == sourceShapedType)
3301  return getViewSource();
3302  }
3303 
3304  return {};
3305 }
3306 
3307 //===----------------------------------------------------------------------===//
3308 // TransposeOp
3309 //===----------------------------------------------------------------------===//
3310 
3311 void TransposeOp::getAsmResultNames(
3312  function_ref<void(Value, StringRef)> setNameFn) {
3313  setNameFn(getResult(), "transpose");
3314 }
3315 
3316 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3317 static MemRefType inferTransposeResultType(MemRefType memRefType,
3318  AffineMap permutationMap) {
3319  auto originalSizes = memRefType.getShape();
3320  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
3321  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3322 
3323  // Compute permuted sizes and strides.
3324  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3325  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3326 
3327  return MemRefType::Builder(memRefType)
3328  .setShape(sizes)
3329  .setLayout(
3330  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3331 }
3332 
3333 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3334  AffineMapAttr permutation,
3335  ArrayRef<NamedAttribute> attrs) {
3336  auto permutationMap = permutation.getValue();
3337  assert(permutationMap);
3338 
3339  auto memRefType = llvm::cast<MemRefType>(in.getType());
3340  // Compute result type.
3341  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3342 
3343  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3344  build(b, result, resultType, in, attrs);
3345 }
3346 
3347 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3349  p << " " << getIn() << " " << getPermutation();
3350  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3351  p << " : " << getIn().getType() << " to " << getType();
3352 }
3353 
3354 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3356  AffineMap permutation;
3357  MemRefType srcType, dstType;
3358  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3359  parser.parseOptionalAttrDict(result.attributes) ||
3360  parser.parseColonType(srcType) ||
3361  parser.resolveOperand(in, srcType, result.operands) ||
3362  parser.parseKeywordType("to", dstType) ||
3363  parser.addTypeToList(dstType, result.types))
3364  return failure();
3365 
3366  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3367  AffineMapAttr::get(permutation));
3368  return success();
3369 }
3370 
3371 LogicalResult TransposeOp::verify() {
3372  if (!getPermutation().isPermutation())
3373  return emitOpError("expected a permutation map");
3374  if (getPermutation().getNumDims() != getIn().getType().getRank())
3375  return emitOpError("expected a permutation map of same rank as the input");
3376 
3377  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3378  auto resultType = llvm::cast<MemRefType>(getType());
3379  auto canonicalResultType = canonicalizeStridedLayout(
3380  inferTransposeResultType(srcType, getPermutation()));
3381 
3382  if (canonicalizeStridedLayout(resultType) != canonicalResultType)
3383  return emitOpError("result type ")
3384  << resultType
3385  << " is not equivalent to the canonical transposed input type "
3386  << canonicalResultType;
3387  return success();
3388 }
3389 
3390 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3391  // First check for identity permutation, we can fold it away if input and
3392  // result types are identical already.
3393  if (getPermutation().isIdentity() && getType() == getIn().getType())
3394  return getIn();
3395  // Fold two consecutive memref.transpose Ops into one by composing their
3396  // permutation maps.
3397  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3398  AffineMap composedPermutation =
3399  getPermutation().compose(otherTransposeOp.getPermutation());
3400  getInMutable().assign(otherTransposeOp.getIn());
3401  setPermutation(composedPermutation);
3402  return getResult();
3403  }
3404  return {};
3405 }
3406 
3407 //===----------------------------------------------------------------------===//
3408 // ViewOp
3409 //===----------------------------------------------------------------------===//
3410 
3411 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3412  setNameFn(getResult(), "view");
3413 }
3414 
3415 LogicalResult ViewOp::verify() {
3416  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3417  auto viewType = getType();
3418 
3419  // The base memref should have identity layout map (or none).
3420  if (!baseType.getLayout().isIdentity())
3421  return emitError("unsupported map for base memref type ") << baseType;
3422 
3423  // The result memref should have identity layout map (or none).
3424  if (!viewType.getLayout().isIdentity())
3425  return emitError("unsupported map for result memref type ") << viewType;
3426 
3427  // The base memref and the view memref should be in the same memory space.
3428  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3429  return emitError("different memory spaces specified for base memref "
3430  "type ")
3431  << baseType << " and view memref type " << viewType;
3432 
3433  // Verify that we have the correct number of sizes for the result type.
3434  unsigned numDynamicDims = viewType.getNumDynamicDims();
3435  if (getSizes().size() != numDynamicDims)
3436  return emitError("incorrect number of size operands for type ") << viewType;
3437 
3438  return success();
3439 }
3440 
3441 Value ViewOp::getViewSource() { return getSource(); }
3442 
3443 namespace {
3444 
3445 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3447 
3448  LogicalResult matchAndRewrite(ViewOp viewOp,
3449  PatternRewriter &rewriter) const override {
3450  // Return if none of the operands are constants.
3451  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3452  return matchPattern(operand, matchConstantIndex());
3453  }))
3454  return failure();
3455 
3456  // Get result memref type.
3457  auto memrefType = viewOp.getType();
3458 
3459  // Get offset from old memref view type 'memRefType'.
3460  int64_t oldOffset;
3461  SmallVector<int64_t, 4> oldStrides;
3462  if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
3463  return failure();
3464  assert(oldOffset == 0 && "Expected 0 offset");
3465 
3466  SmallVector<Value, 4> newOperands;
3467 
3468  // Offset cannot be folded into result type.
3469 
3470  // Fold any dynamic dim operands which are produced by a constant.
3471  SmallVector<int64_t, 4> newShapeConstants;
3472  newShapeConstants.reserve(memrefType.getRank());
3473 
3474  unsigned dynamicDimPos = 0;
3475  unsigned rank = memrefType.getRank();
3476  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3477  int64_t dimSize = memrefType.getDimSize(dim);
3478  // If this is already static dimension, keep it.
3479  if (!ShapedType::isDynamic(dimSize)) {
3480  newShapeConstants.push_back(dimSize);
3481  continue;
3482  }
3483  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3484  if (auto constantIndexOp =
3485  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3486  // Dynamic shape dimension will be folded.
3487  newShapeConstants.push_back(constantIndexOp.value());
3488  } else {
3489  // Dynamic shape dimension not folded; copy operand from old memref.
3490  newShapeConstants.push_back(dimSize);
3491  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3492  }
3493  dynamicDimPos++;
3494  }
3495 
3496  // Create new memref type with constant folded dims.
3497  MemRefType newMemRefType =
3498  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3499  // Nothing new, don't fold.
3500  if (newMemRefType == memrefType)
3501  return failure();
3502 
3503  // Create new ViewOp.
3504  auto newViewOp = rewriter.create<ViewOp>(
3505  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3506  viewOp.getByteShift(), newOperands);
3507  // Insert a cast so we have the same type as the old memref type.
3508  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3509  return success();
3510  }
3511 };
3512 
3513 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3515 
3516  LogicalResult matchAndRewrite(ViewOp viewOp,
3517  PatternRewriter &rewriter) const override {
3518  Value memrefOperand = viewOp.getOperand(0);
3519  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3520  if (!memrefCastOp)
3521  return failure();
3522  Value allocOperand = memrefCastOp.getOperand();
3523  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3524  if (!allocOp)
3525  return failure();
3526  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3527  viewOp.getByteShift(),
3528  viewOp.getSizes());
3529  return success();
3530  }
3531 };
3532 
3533 } // namespace
3534 
3535 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3536  MLIRContext *context) {
3537  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3538 }
3539 
3540 //===----------------------------------------------------------------------===//
3541 // AtomicRMWOp
3542 //===----------------------------------------------------------------------===//
3543 
3544 LogicalResult AtomicRMWOp::verify() {
3545  if (getMemRefType().getRank() != getNumOperands() - 2)
3546  return emitOpError(
3547  "expects the number of subscripts to be equal to memref rank");
3548  switch (getKind()) {
3549  case arith::AtomicRMWKind::addf:
3550  case arith::AtomicRMWKind::maximumf:
3551  case arith::AtomicRMWKind::minimumf:
3552  case arith::AtomicRMWKind::mulf:
3553  if (!llvm::isa<FloatType>(getValue().getType()))
3554  return emitOpError() << "with kind '"
3555  << arith::stringifyAtomicRMWKind(getKind())
3556  << "' expects a floating-point type";
3557  break;
3558  case arith::AtomicRMWKind::addi:
3559  case arith::AtomicRMWKind::maxs:
3560  case arith::AtomicRMWKind::maxu:
3561  case arith::AtomicRMWKind::mins:
3562  case arith::AtomicRMWKind::minu:
3563  case arith::AtomicRMWKind::muli:
3564  case arith::AtomicRMWKind::ori:
3565  case arith::AtomicRMWKind::andi:
3566  if (!llvm::isa<IntegerType>(getValue().getType()))
3567  return emitOpError() << "with kind '"
3568  << arith::stringifyAtomicRMWKind(getKind())
3569  << "' expects an integer type";
3570  break;
3571  default:
3572  break;
3573  }
3574  return success();
3575 }
3576 
3577 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3578  /// atomicrmw(memrefcast) -> atomicrmw
3579  if (succeeded(foldMemRefCast(*this, getValue())))
3580  return getResult();
3581  return OpFoldResult();
3582 }
3583 
3584 //===----------------------------------------------------------------------===//
3585 // TableGen'd op method definitions
3586 //===----------------------------------------------------------------------===//
3587 
3588 #define GET_OP_CLASSES
3589 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
static bool hasSideEffects(Operation *op)
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isPermutation(std::vector< PermutationTy > permutation)
Definition: IRAffine.cpp:71
static MLIRContext * getContext(OpFoldResult val)
static SmallVector< int64_t > getConstantOffset(MemRefType memrefType)
Wrapper around getStridesAndOffset that returns only the offset and conforms to the function signatur...
Definition: MemRefOps.cpp:163
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, MemRefType memRefTy, MLIRContext *ctxt, llvm::function_ref< SmallVector< int64_t >(MemRefType)> getAttributes, llvm::function_ref< bool(int64_t)> isDynamic)
Helper function that infers the constant values from a list of values, a memRefTy,...
Definition: MemRefOps.cpp:115
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
Definition: MemRefOps.cpp:1554
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
Definition: MemRefOps.cpp:2113
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
Definition: MemRefOps.cpp:451
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
Definition: MemRefOps.cpp:3317
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
Definition: MemRefOps.cpp:432
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
Definition: MemRefOps.cpp:2891
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
Definition: MemRefOps.cpp:1401
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
Definition: MemRefOps.cpp:3046
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, Operation *op, Type expectedType)
Definition: MemRefOps.cpp:2924
static SmallVector< int64_t > getConstantStrides(MemRefType memrefType)
Wrapper around getStridesAndOffset that returns only the strides and conforms to the function signatu...
Definition: MemRefOps.cpp:176
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition: MemRefOps.cpp:1568
static SmallVector< int64_t > getConstantSizes(MemRefType memRefTy)
Wrapper around getShape that conforms to the function signature expected for getAttributes in constif...
Definition: MemRefOps.cpp:155
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
Definition: MemRefOps.cpp:941
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
Definition: MemRefOps.cpp:3111
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
Definition: MemRefOps.cpp:474
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
Definition: MemRefOps.cpp:2902
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map whose keys are the elements of vals and whose values are the number of occurrences of each element.
Definition: MemRefOps.cpp:926
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
Definition: MemRefOps.cpp:2206
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
Definition: MemRefOps.cpp:2401
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
Definition: MemRefOps.cpp:201
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1545
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition: TensorOps.cpp:135
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:543
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:138
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
IndexType getIndexType()
Definition: Builders.cpp:71
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:201
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:222
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:212
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Definition: Operation.cpp:227
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Region.h:98
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIndex() const
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition: MemRefOps.cpp:59
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition: MemRefOps.cpp:67
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
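Create a rank-reducing SubViewOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes to reduce the rank of memref to that of targetShape.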
Definition: MemRefOps.cpp:3078
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:334
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:319
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:375
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size, stride) for the op.
Definition: MemRefOps.cpp:3013
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:24
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
Move allocations into an allocation scope, if it is legal to move them (e.g.
Definition: MemRefOps.cpp:521
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:524
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
Definition: MemRefOps.cpp:481
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:484
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:2564
A canonicalizer wrapper to replace SubViewOps.
Definition: MemRefOps.cpp:3263
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Definition: MemRefOps.cpp:3264
Return the canonical type of the result of a subview.
Definition: MemRefOps.cpp:3226
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: MemRefOps.cpp:3227
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.