MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Arith/Utils/Utils.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/Utils/StaticValueUtils.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
23 #include "mlir/Interfaces/ViewLikeInterface.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::memref;
29 
30 /// Materialize a single constant operation from a given attribute value with
31 /// the desired resultant type.
32 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
33                                               Attribute value, Type type,
34  Location loc) {
35  return arith::ConstantOp::materialize(builder, value, type, loc);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Common canonicalization pattern support logic
40 //===----------------------------------------------------------------------===//
41 
42 /// This is a common class used for patterns of the form
43 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
44 /// into the root operation directly.
45 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
46  bool folded = false;
47  for (OpOperand &operand : op->getOpOperands()) {
48  auto cast = operand.get().getDefiningOp<CastOp>();
49  if (cast && operand.get() != inner &&
50  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
51  operand.set(cast.getOperand());
52  folded = true;
53  }
54  }
55  return success(folded);
56 }
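// Example (illustrative, hypothetical values): DeallocOp's folder below uses
// this helper, so
//
//   %0 = memref.cast %arg : memref<8x16xf32> to memref<?x?xf32>
//   memref.dealloc %0 : memref<?x?xf32>
//
// folds to:
//
//   memref.dealloc %arg : memref<8x16xf32>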
57 
58 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
59 /// type.
60 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
61   if (auto memref = llvm::dyn_cast<MemRefType>(type))
62  return RankedTensorType::get(memref.getShape(), memref.getElementType());
63  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
64  return UnrankedTensorType::get(memref.getElementType());
65  return NoneType::get(type.getContext());
66 }
67 
68 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
69                                   int64_t dim) {
70  auto memrefType = llvm::cast<MemRefType>(value.getType());
72  if (memrefType.isDynamicDim(dim))
73  return builder.createOrFold<memref::DimOp>(loc, value, dim);
74 
75  return builder.getIndexAttr(memrefType.getDimSize(dim));
76 }
77 
78 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
79                                                 Location loc, Value value) {
80  auto memrefType = llvm::cast<MemRefType>(value.getType());
81   SmallVector<OpFoldResult> result;
82   for (int64_t i = 0; i < memrefType.getRank(); ++i)
83  result.push_back(getMixedSize(builder, loc, value, i));
84  return result;
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // Utility functions for propagating static information
89 //===----------------------------------------------------------------------===//
90 
91 /// Helper function that infers the constant values from a list of \p values,
92 /// a \p memRefTy, and another helper function \p getAttributes.
93 /// The inferred constant values replace the related `OpFoldResult` in
94 /// \p values.
95 ///
96 /// \note This function shouldn't be used directly, instead, use the
97 /// `getConstifiedMixedXXX` methods from the related operations.
98 ///
99 /// \p getAttributes returns a list of potentially constant values, as determined
100 /// by \p isDynamic, from the given \p memRefTy. The returned list must have as
101 /// many elements as \p values or be empty.
102 ///
103 /// E.g., consider the following example:
104 /// ```
105 /// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
106 /// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
107 /// ```
108 /// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
109 /// Now using this helper function with:
110 /// - `values == [2, %dyn_stride]`,
111 /// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
112 /// - `getAttributes == getConstantStrides` (i.e., a wrapper around
113 /// `getStridesAndOffset`), and
114 /// - `isDynamic == ShapedType::isDynamic`
115 /// Will yield: `values == [2, 1]`
116 static void constifyIndexValues(
117     SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
118  MLIRContext *ctxt,
119  llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
120  llvm::function_ref<bool(int64_t)> isDynamic) {
121  SmallVector<int64_t> constValues = getAttributes(memRefTy);
122  Builder builder(ctxt);
123  for (const auto &it : llvm::enumerate(constValues)) {
124  int64_t constValue = it.value();
125  if (!isDynamic(constValue))
126  values[it.index()] = builder.getIndexAttr(constValue);
127  }
128  for (OpFoldResult &ofr : values) {
129  if (auto attr = dyn_cast<Attribute>(ofr)) {
130  // FIXME: We shouldn't need to do that, but right now, the static indices
131  // are created with the wrong type: `i64` instead of `index`.
132  // As a result, if we were to keep the attribute as is, we may fail to see
133  // that two attributes are equal because one would have the i64 type and
134  // the other the index type.
135  // The alternative would be to create constant indices with getI64Attr in
136  // this and the previous loop, but it doesn't logically make sense (we are
137   // dealing with indices here) and would only strengthen the inconsistency
138  // around how static indices are created (some places use getI64Attr,
139  // others use getIndexAttr).
140  // The workaround here is to stick to the IndexAttr type for all the
141  // values, hence we recreate the attribute even when it is already static
142  // to make sure the type is consistent.
143  ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
144  continue;
145  }
146  std::optional<int64_t> maybeConstant =
147  getConstantIntValue(cast<Value>(ofr));
148  if (maybeConstant)
149  ofr = builder.getIndexAttr(*maybeConstant);
150  }
151 }
152 
153 /// Wrapper around `getShape` that conforms to the function signature
154 /// expected for `getAttributes` in `constifyIndexValues`.
155 static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
156  ArrayRef<int64_t> sizes = memRefTy.getShape();
157  return SmallVector<int64_t>(sizes);
158 }
159 
160 /// Wrapper around `getStridesAndOffset` that returns only the offset and
161 /// conforms to the function signature expected for `getAttributes` in
162 /// `constifyIndexValues`.
163 static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
164  SmallVector<int64_t> strides;
165  int64_t offset;
166  LogicalResult hasStaticInformation =
167  memrefType.getStridesAndOffset(strides, offset);
168  if (failed(hasStaticInformation))
169  return SmallVector<int64_t>();
170  return SmallVector<int64_t>(1, offset);
171 }
172 
173 /// Wrapper around `getStridesAndOffset` that returns only the strides and
174 /// conforms to the function signature expected for `getAttributes` in
175 /// `constifyIndexValues`.
176 static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
177  SmallVector<int64_t> strides;
178  int64_t offset;
179  LogicalResult hasStaticInformation =
180  memrefType.getStridesAndOffset(strides, offset);
181  if (failed(hasStaticInformation))
182  return SmallVector<int64_t>();
183  return strides;
184 }
185 
186 //===----------------------------------------------------------------------===//
187 // AllocOp / AllocaOp
188 //===----------------------------------------------------------------------===//
189 
190 void AllocOp::getAsmResultNames(
191  function_ref<void(Value, StringRef)> setNameFn) {
192  setNameFn(getResult(), "alloc");
193 }
194 
195 void AllocaOp::getAsmResultNames(
196  function_ref<void(Value, StringRef)> setNameFn) {
197  setNameFn(getResult(), "alloca");
198 }
199 
200 template <typename AllocLikeOp>
201 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
202  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
203  "applies to only alloc or alloca");
204  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
205  if (!memRefType)
206  return op.emitOpError("result must be a memref");
207 
208  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
209  return op.emitOpError("dimension operand count does not equal memref "
210  "dynamic dimension count");
211 
212  unsigned numSymbols = 0;
213  if (!memRefType.getLayout().isIdentity())
214  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
215  if (op.getSymbolOperands().size() != numSymbols)
216  return op.emitOpError("symbol operand count does not equal memref symbol "
217  "count: expected ")
218  << numSymbols << ", got " << op.getSymbolOperands().size();
219 
220  return success();
221 }
222 
223 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
224 
225 LogicalResult AllocaOp::verify() {
226  // An alloca op needs to have an ancestor with an allocation scope trait.
227  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
228  return emitOpError(
229  "requires an ancestor op with AutomaticAllocationScope trait");
230 
231  return verifyAllocLikeOp(*this);
232 }
233 
234 namespace {
235 /// Fold constant dimensions into an alloc like operation.
236 template <typename AllocLikeOp>
237 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
238   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
239 
240  LogicalResult matchAndRewrite(AllocLikeOp alloc,
241  PatternRewriter &rewriter) const override {
242     // Check to see if any dimension operands are constants. If so, we can
243  // substitute and drop them.
244  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
245  APInt constSizeArg;
246  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
247  return false;
248  return constSizeArg.isNonNegative();
249  }))
250  return failure();
251 
252  auto memrefType = alloc.getType();
253 
254  // Ok, we have one or more constant operands. Collect the non-constant ones
255  // and keep track of the resultant memref type to build.
256  SmallVector<int64_t, 4> newShapeConstants;
257  newShapeConstants.reserve(memrefType.getRank());
258  SmallVector<Value, 4> dynamicSizes;
259 
260  unsigned dynamicDimPos = 0;
261  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
262  int64_t dimSize = memrefType.getDimSize(dim);
263       // If this is already a static dimension, keep it.
264  if (!ShapedType::isDynamic(dimSize)) {
265  newShapeConstants.push_back(dimSize);
266  continue;
267  }
268  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
269  APInt constSizeArg;
270  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
271  constSizeArg.isNonNegative()) {
272  // Dynamic shape dimension will be folded.
273  newShapeConstants.push_back(constSizeArg.getZExtValue());
274  } else {
275  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
276  newShapeConstants.push_back(ShapedType::kDynamic);
277  dynamicSizes.push_back(dynamicSize);
278  }
279  dynamicDimPos++;
280  }
281 
282  // Create new memref type (which will have fewer dynamic dimensions).
283  MemRefType newMemRefType =
284  MemRefType::Builder(memrefType).setShape(newShapeConstants);
285  assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
286 
287  // Create and insert the alloc op for the new memref.
288  auto newAlloc = rewriter.create<AllocLikeOp>(
289  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
290  alloc.getAlignmentAttr());
291  // Insert a cast so we have the same type as the old alloc.
292  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
293  return success();
294  }
295 };
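// Example (illustrative, hypothetical values): with %c64 defined by
// `arith.constant 64 : index`,
//
//   %0 = memref.alloc(%c64, %w) : memref<?x?xf32>
//
// is rewritten to:
//
//   %1 = memref.alloc(%w) : memref<64x?xf32>
//   %0 = memref.cast %1 : memref<64x?xf32> to memref<?x?xf32>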
296 
297 /// Fold alloc operations with no users or only store and dealloc uses.
298 template <typename T>
299 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
300   using OpRewritePattern<T>::OpRewritePattern;
301 
302  LogicalResult matchAndRewrite(T alloc,
303  PatternRewriter &rewriter) const override {
304  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
305  if (auto storeOp = dyn_cast<StoreOp>(op))
306  return storeOp.getValue() == alloc;
307  return !isa<DeallocOp>(op);
308  }))
309  return failure();
310 
311  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
312  rewriter.eraseOp(user);
313 
314  rewriter.eraseOp(alloc);
315  return success();
316  }
317 };
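// Example (illustrative, hypothetical values): an alloc whose only uses are
// stores into it and a dealloc is erased together with those uses:
//
//   %0 = memref.alloc() : memref<4xf32>
//   memref.store %v, %0[%i] : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>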
318 } // namespace
319 
320 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
321  MLIRContext *context) {
322  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
323 }
324 
325 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
326  MLIRContext *context) {
327  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
328  context);
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // ReallocOp
333 //===----------------------------------------------------------------------===//
334 
335 LogicalResult ReallocOp::verify() {
336  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
337  MemRefType resultType = getType();
338 
339  // The source memref should have identity layout (or none).
340  if (!sourceType.getLayout().isIdentity())
341  return emitError("unsupported layout for source memref type ")
342  << sourceType;
343 
344  // The result memref should have identity layout (or none).
345  if (!resultType.getLayout().isIdentity())
346  return emitError("unsupported layout for result memref type ")
347  << resultType;
348 
349  // The source memref and the result memref should be in the same memory space.
350  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
351  return emitError("different memory spaces specified for source memref "
352  "type ")
353  << sourceType << " and result memref type " << resultType;
354 
355  // The source memref and the result memref should have the same element type.
356  if (sourceType.getElementType() != resultType.getElementType())
357  return emitError("different element types specified for source memref "
358  "type ")
359  << sourceType << " and result memref type " << resultType;
360 
361  // Verify that we have the dynamic dimension operand when it is needed.
362  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
363  return emitError("missing dimension operand for result type ")
364  << resultType;
365  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
366  return emitError("unnecessary dimension operand for result type ")
367  << resultType;
368 
369  return success();
370 }
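// Example (illustrative, hypothetical values) of reallocs accepted by this
// verifier; the dynamic result-size operand is present exactly when the result
// type has a dynamic dimension:
//
//   %0 = memref.realloc %src : memref<64xf32> to memref<124xf32>
//   %1 = memref.realloc %src(%d) : memref<64xf32> to memref<?xf32>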
371 
372 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
373  MLIRContext *context) {
374  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // AllocaScopeOp
379 //===----------------------------------------------------------------------===//
380 
381 void AllocaScopeOp::print(OpAsmPrinter &p) {
382   bool printBlockTerminators = false;
383 
384  p << ' ';
385  if (!getResults().empty()) {
386  p << " -> (" << getResultTypes() << ")";
387  printBlockTerminators = true;
388  }
389  p << ' ';
390  p.printRegion(getBodyRegion(),
391  /*printEntryBlockArgs=*/false,
392  /*printBlockTerminators=*/printBlockTerminators);
393  p.printOptionalAttrDict((*this)->getAttrs());
394 }
395 
396 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
397  // Create a region for the body.
398  result.regions.reserve(1);
399  Region *bodyRegion = result.addRegion();
400 
401  // Parse optional results type list.
402  if (parser.parseOptionalArrowTypeList(result.types))
403  return failure();
404 
405  // Parse the body region.
406  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
407  return failure();
408  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
409  result.location);
410 
411  // Parse the optional attribute list.
412  if (parser.parseOptionalAttrDict(result.attributes))
413  return failure();
414 
415  return success();
416 }
417 
418 void AllocaScopeOp::getSuccessorRegions(
419     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
420   if (!point.isParent()) {
421  regions.push_back(RegionSuccessor(getResults()));
422  return;
423  }
424 
425  regions.push_back(RegionSuccessor(&getBodyRegion()));
426 }
427 
428 /// Given an operation, return whether this op is guaranteed to
429 /// allocate an AutomaticAllocationScopeResource
430 static bool isGuaranteedAutomaticAllocation(Operation *op) {
431   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
432  if (!interface)
433  return false;
434  for (auto res : op->getResults()) {
435  if (auto effect =
436  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
437  if (isa<SideEffects::AutomaticAllocationScopeResource>(
438  effect->getResource()))
439  return true;
440  }
441  }
442  return false;
443 }
444 
445 /// Given an operation, return whether this op itself could
446 /// allocate an AutomaticAllocationScopeResource. Note that
447 /// this will not check whether an operation contained within
448 /// the op can allocate.
449 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
450   // This op itself doesn't create a stack allocation,
451   // the inner allocation should be handled separately.
452   if (op->hasTrait<OpTrait::AutomaticAllocationScope>())
453     return false;
454  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
455  if (!interface)
456  return true;
457  for (auto res : op->getResults()) {
458  if (auto effect =
459  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
460  if (isa<SideEffects::AutomaticAllocationScopeResource>(
461  effect->getResource()))
462  return true;
463  }
464  }
465  return false;
466 }
467 
468 /// Return whether this op is the last non terminating op
469 /// in a region. That is to say, it is in a one-block region
470 /// and is only followed by a terminator. This prevents
471 /// extending the lifetime of allocations.
472 static bool lastNonTerminatorInRegion(Operation *op) {
473   return op->getNextNode() == op->getBlock()->getTerminator() &&
474  llvm::hasSingleElement(op->getParentRegion()->getBlocks());
475 }
476 
477 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
478 /// or it contains no allocation.
479 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
480   using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
481 
482  LogicalResult matchAndRewrite(AllocaScopeOp op,
483  PatternRewriter &rewriter) const override {
484  bool hasPotentialAlloca =
485  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
486  if (alloc == op)
487  return WalkResult::advance();
488           if (isOpItselfPotentialAutomaticAllocation(alloc))
489             return WalkResult::interrupt();
490  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
491  return WalkResult::skip();
492  return WalkResult::advance();
493  }).wasInterrupted();
494 
495  // If this contains no potential allocation, it is always legal to
496  // inline. Otherwise, consider two conditions:
497  if (hasPotentialAlloca) {
498  // If the parent isn't an allocation scope, or we are not the last
499  // non-terminator op in the parent, we will extend the lifetime.
500  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
501  return failure();
502  if (!lastNonTerminatorInRegion(op))
503  return failure();
504  }
505 
506  Block *block = &op.getRegion().front();
507  Operation *terminator = block->getTerminator();
508  ValueRange results = terminator->getOperands();
509  rewriter.inlineBlockBefore(block, op);
510  rewriter.replaceOp(op, results);
511  rewriter.eraseOp(terminator);
512  return success();
513  }
514 };
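// Example (illustrative, hypothetical values): a scope that contains no
// alloca-like op can always be inlined into its parent:
//
//   %r = memref.alloca_scope -> (index) {
//     %c = arith.constant 0 : index
//     memref.alloca_scope.return %c : index
//   }
//
// becomes:
//
//   %c = arith.constant 0 : index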
515 
516 /// Move allocations into an allocation scope, if it is legal to
517 /// move them (e.g. their operands are available at the location
518 /// the op would be moved to).
519 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
520   using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
521 
522  LogicalResult matchAndRewrite(AllocaScopeOp op,
523  PatternRewriter &rewriter) const override {
524 
525  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
526  return failure();
527 
528  Operation *lastParentWithoutScope = op->getParentOp();
529 
530  if (!lastParentWithoutScope ||
531  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
532  return failure();
533 
534     // Only apply if this is the last non-terminator
535     // op in the block (lest lifetime be extended) of a
536     // one-block region.
537  if (!lastNonTerminatorInRegion(op) ||
538  !lastNonTerminatorInRegion(lastParentWithoutScope))
539  return failure();
540 
541  while (!lastParentWithoutScope->getParentOp()
542                 ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
543       lastParentWithoutScope = lastParentWithoutScope->getParentOp();
544  if (!lastParentWithoutScope ||
545  !lastNonTerminatorInRegion(lastParentWithoutScope))
546  return failure();
547  }
548  assert(lastParentWithoutScope->getParentOp()
549                ->hasTrait<OpTrait::AutomaticAllocationScope>());
550 
551  Region *containingRegion = nullptr;
552  for (auto &r : lastParentWithoutScope->getRegions()) {
553  if (r.isAncestor(op->getParentRegion())) {
554  assert(containingRegion == nullptr &&
555  "only one region can contain the op");
556  containingRegion = &r;
557  }
558  }
559  assert(containingRegion && "op must be contained in a region");
560 
561  SmallVector<Operation *> toHoist;
562  op->walk([&](Operation *alloc) {
563       if (!isGuaranteedAutomaticAllocation(alloc))
564         return WalkResult::skip();
565 
566  // If any operand is not defined before the location of
567  // lastParentWithoutScope (i.e. where we would hoist to), skip.
568  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
569  return containingRegion->isAncestor(v.getParentRegion());
570  }))
571  return WalkResult::skip();
572  toHoist.push_back(alloc);
573  return WalkResult::advance();
574  });
575 
576  if (toHoist.empty())
577  return failure();
578  rewriter.setInsertionPoint(lastParentWithoutScope);
579  for (auto *op : toHoist) {
580  auto *cloned = rewriter.clone(*op);
581  rewriter.replaceOp(op, cloned->getResults());
582  }
583  return success();
584  }
585 };
586 
587 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
588  MLIRContext *context) {
589  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // AssumeAlignmentOp
594 //===----------------------------------------------------------------------===//
595 
596 LogicalResult AssumeAlignmentOp::verify() {
597  if (!llvm::isPowerOf2_32(getAlignment()))
598  return emitOpError("alignment must be power of 2");
599  return success();
600 }
601 
602 //===----------------------------------------------------------------------===//
603 // CastOp
604 //===----------------------------------------------------------------------===//
605 
606 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
607  setNameFn(getResult(), "cast");
608 }
609 
610 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
611 /// source memref. This is useful to fold a memref.cast into a consuming op
612 /// and implement canonicalization patterns for ops in different dialects that
613 /// may consume the results of memref.cast operations. Such foldable memref.cast
614 /// operations are typically inserted as `view` and `subview` ops are
615 /// canonicalized, to preserve the type compatibility of their uses.
616 ///
617 /// Returns true when all conditions are met:
618 /// 1. source and result are ranked memrefs with strided semantics and same
619 /// element type and rank.
620 /// 2. each of the source's size, offset or stride has more static information
621 /// than the corresponding result's size, offset or stride.
622 ///
623 /// Example 1:
624 /// ```mlir
625 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
626 /// %2 = consumer %1 ... : memref<?x?xf32> ...
627 /// ```
628 ///
629 /// may fold into:
630 ///
631 /// ```mlir
632 /// %2 = consumer %0 ... : memref<8x16xf32> ...
633 /// ```
634 ///
635 /// Example 2:
636 /// ```
637 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
638 /// to memref<?x?xf32>
639 /// consumer %1 : memref<?x?xf32> ...
640 /// ```
641 ///
642 /// may fold into:
643 ///
644 /// ```
645 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
646 /// ```
647 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
648  MemRefType sourceType =
649  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
650  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
651 
652  // Requires ranked MemRefType.
653  if (!sourceType || !resultType)
654  return false;
655 
656  // Requires same elemental type.
657  if (sourceType.getElementType() != resultType.getElementType())
658  return false;
659 
660  // Requires same rank.
661  if (sourceType.getRank() != resultType.getRank())
662  return false;
663 
664  // Only fold casts between strided memref forms.
665  int64_t sourceOffset, resultOffset;
666  SmallVector<int64_t, 4> sourceStrides, resultStrides;
667  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
668  failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
669  return false;
670 
671  // If cast is towards more static sizes along any dimension, don't fold.
672  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
673  auto ss = std::get<0>(it), st = std::get<1>(it);
674  if (ss != st)
675  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
676  return false;
677  }
678 
679  // If cast is towards more static offset along any dimension, don't fold.
680  if (sourceOffset != resultOffset)
681  if (ShapedType::isDynamic(sourceOffset) &&
682  !ShapedType::isDynamic(resultOffset))
683  return false;
684 
685  // If cast is towards more static strides along any dimension, don't fold.
686  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
687  auto ss = std::get<0>(it), st = std::get<1>(it);
688  if (ss != st)
689  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
690  return false;
691  }
692 
693  return true;
694 }
695 
696 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
697  if (inputs.size() != 1 || outputs.size() != 1)
698  return false;
699  Type a = inputs.front(), b = outputs.front();
700  auto aT = llvm::dyn_cast<MemRefType>(a);
701  auto bT = llvm::dyn_cast<MemRefType>(b);
702 
703  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
704  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
705 
706  if (aT && bT) {
707  if (aT.getElementType() != bT.getElementType())
708  return false;
709  if (aT.getLayout() != bT.getLayout()) {
710  int64_t aOffset, bOffset;
711  SmallVector<int64_t, 4> aStrides, bStrides;
712  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
713  failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
714  aStrides.size() != bStrides.size())
715  return false;
716 
717  // Strides along a dimension/offset are compatible if the value in the
718  // source memref is static and the value in the target memref is the
719  // same. They are also compatible if either one is dynamic (see
720  // description of MemRefCastOp for details).
721  auto checkCompatible = [](int64_t a, int64_t b) {
722  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
723  };
724  if (!checkCompatible(aOffset, bOffset))
725  return false;
726  for (const auto &aStride : enumerate(aStrides))
727  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
728  return false;
729  }
730  if (aT.getMemorySpace() != bT.getMemorySpace())
731  return false;
732 
733  // They must have the same rank, and any specified dimensions must match.
734  if (aT.getRank() != bT.getRank())
735  return false;
736 
737  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
738  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
739  if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
740  aDim != bDim)
741  return false;
742  }
743  return true;
744  } else {
745  if (!aT && !uaT)
746  return false;
747  if (!bT && !ubT)
748  return false;
749  // Unranked to unranked casting is unsupported
750  if (uaT && ubT)
751  return false;
752 
753  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
754  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
755  if (aEltType != bEltType)
756  return false;
757 
758  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
759  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
760  return aMemSpace == bMemSpace;
761  }
762 
763  return false;
764 }
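// Example (illustrative, hypothetical values): ranked <-> unranked casts with
// matching element type and memory space are cast-compatible, while
// unranked -> unranked is not:
//
//   memref.cast %a : memref<4xf32> to memref<*xf32>     // compatible
//   memref.cast %b : memref<*xf32> to memref<?x4xf32>   // compatible
//   memref.cast %c : memref<*xf32> to memref<*xf32>     // rejected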
765 
766 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
767  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // CopyOp
772 //===----------------------------------------------------------------------===//
773 
774 namespace {
775 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
776 /// and element type, the cast can be skipped. Such CastOps only cast the layout
777 /// of the type.
778 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
779   using OpRewritePattern<CopyOp>::OpRewritePattern;
780 
781  LogicalResult matchAndRewrite(CopyOp copyOp,
782  PatternRewriter &rewriter) const override {
783  bool modified = false;
784 
785  // Check source.
786  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
787  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
788       auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
789 
790  if (fromType && toType) {
791  if (fromType.getShape() == toType.getShape() &&
792  fromType.getElementType() == toType.getElementType()) {
793  rewriter.modifyOpInPlace(copyOp, [&] {
794  copyOp.getSourceMutable().assign(castOp.getSource());
795  });
796  modified = true;
797  }
798  }
799  }
800 
801  // Check target.
802  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
803  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
804       auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
805 
806  if (fromType && toType) {
807  if (fromType.getShape() == toType.getShape() &&
808  fromType.getElementType() == toType.getElementType()) {
809  rewriter.modifyOpInPlace(copyOp, [&] {
810  copyOp.getTargetMutable().assign(castOp.getSource());
811  });
812  modified = true;
813  }
814  }
815  }
816 
817  return success(modified);
818  }
819 };
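// Example (illustrative, hypothetical values): a cast that only changes the
// layout of the type can be skipped on either side of the copy:
//
//   %0 = memref.cast %src : memref<4xf32> to memref<4xf32, strided<[1]>>
//   memref.copy %0, %dst : memref<4xf32, strided<[1]>> to memref<4xf32>
//
// becomes:
//
//   memref.copy %src, %dst : memref<4xf32> to memref<4xf32>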
820 
821 /// Fold memref.copy(%x, %x).
822 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
823   using OpRewritePattern<CopyOp>::OpRewritePattern;
824 
825  LogicalResult matchAndRewrite(CopyOp copyOp,
826  PatternRewriter &rewriter) const override {
827  if (copyOp.getSource() != copyOp.getTarget())
828  return failure();
829 
830  rewriter.eraseOp(copyOp);
831  return success();
832  }
833 };
834 
835 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
836   using OpRewritePattern<CopyOp>::OpRewritePattern;
837 
838  static bool isEmptyMemRef(BaseMemRefType type) {
839  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
840  }
841 
842  LogicalResult matchAndRewrite(CopyOp copyOp,
843  PatternRewriter &rewriter) const override {
844  if (isEmptyMemRef(copyOp.getSource().getType()) ||
845  isEmptyMemRef(copyOp.getTarget().getType())) {
846  rewriter.eraseOp(copyOp);
847  return success();
848  }
849 
850  return failure();
851  }
852 };
853 } // namespace
854 
855 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
856  MLIRContext *context) {
857  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
858 }
859 
860 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
861                            SmallVectorImpl<OpFoldResult> &results) {
862   /// copy(memrefcast) -> copy
863  bool folded = false;
864  Operation *op = *this;
865  for (OpOperand &operand : op->getOpOperands()) {
866  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
867  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
868  operand.set(castOp.getOperand());
869  folded = true;
870  }
871  }
872  return success(folded);
873 }
874 
875 //===----------------------------------------------------------------------===//
876 // DeallocOp
877 //===----------------------------------------------------------------------===//
878 
879 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
880                               SmallVectorImpl<OpFoldResult> &results) {
881   /// dealloc(memrefcast) -> dealloc
882  return foldMemRefCast(*this);
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // DimOp
887 //===----------------------------------------------------------------------===//
888 
889 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
890  setNameFn(getResult(), "dim");
891 }
892 
893 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
894  int64_t index) {
895  auto loc = result.location;
896  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
897  build(builder, result, source, indexValue);
898 }
899 
900 std::optional<int64_t> DimOp::getConstantIndex() {
901  return getConstantIntValue(getIndex());
902 }
903 
904 Speculation::Speculatability DimOp::getSpeculatability() {
905  auto constantIndex = getConstantIndex();
906  if (!constantIndex)
907     return Speculation::NotSpeculatable;
908 
909   auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
910   if (!rankedSourceType)
911     return Speculation::NotSpeculatable;
912 
913   if (rankedSourceType.getRank() <= constantIndex)
914     return Speculation::NotSpeculatable;
915 
916   return Speculation::Speculatable;
917 }
918 
919 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
920  SetIntLatticeFn setResultRange) {
921  setResultRange(getResult(),
922  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
923 }
924 
925 /// Return a map from each element in `vals` to its number of occurrences.
926 /// Use std::map, since the `vals` here are strides and the
927 /// dynamic stride value is the same as the tombstone value for
928 /// `DenseMap<int64_t>`.
929 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
930  std::map<int64_t, unsigned> numOccurences;
931  for (auto val : vals)
932  numOccurences[val]++;
933  return numOccurences;
934 }
935 
936 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
937 /// to be a subset of `originalType` with some `1` entries erased, return the
938 /// set of indices that specifies which of the entries of `originalShape` are
939 /// dropped to obtain `reducedShape`.
940 /// This accounts for cases where there are multiple unit-dims, but only a
941 /// subset of those are dropped. For MemRefTypes these can be disambiguated
942 /// using the strides. If a dimension is dropped the stride must be dropped too.
943 static FailureOr<llvm::SmallBitVector>
944 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
945  ArrayRef<OpFoldResult> sizes) {
946  llvm::SmallBitVector unusedDims(originalType.getRank());
947  if (originalType.getRank() == reducedType.getRank())
948  return unusedDims;
949 
950  for (const auto &dim : llvm::enumerate(sizes))
951  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
952  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
953  unusedDims.set(dim.index());
954 
955  // Early exit for the case where the number of unused dims matches the number
956  // of ranks reduced.
957  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
958  originalType.getRank())
959  return unusedDims;
960 
961  SmallVector<int64_t> originalStrides, candidateStrides;
962  int64_t originalOffset, candidateOffset;
963  if (failed(
964  originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
965  failed(
966  reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
967  return failure();
968 
969   // For memrefs, a dimension is truly dropped if its corresponding stride is
970   // also dropped. This is particularly important when more than one of the
971   // dims is 1. Track the number of occurrences of the strides in the original
972   // type and the candidate type. For each unused dim, that stride should not
973   // be present in the candidate type. Note that there could be multiple
974   // dimensions with the same size. We don't need to figure out exactly which
975   // dim corresponds to which stride; we just need to verify that the number
976   // of repetitions of a stride in the candidate + the number of unused dims
977   // with that stride == the number of repetitions of that stride in the original.
978  std::map<int64_t, unsigned> currUnaccountedStrides =
979  getNumOccurences(originalStrides);
980  std::map<int64_t, unsigned> candidateStridesNumOccurences =
981  getNumOccurences(candidateStrides);
982  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
983  if (!unusedDims.test(dim))
984  continue;
985  int64_t originalStride = originalStrides[dim];
986  if (currUnaccountedStrides[originalStride] >
987  candidateStridesNumOccurences[originalStride]) {
988  // This dim can be treated as dropped.
989  currUnaccountedStrides[originalStride]--;
990  continue;
991  }
992  if (currUnaccountedStrides[originalStride] ==
993  candidateStridesNumOccurences[originalStride]) {
994  // The stride for this is not dropped. Keep as is.
995  unusedDims.reset(dim);
996  continue;
997  }
998  if (currUnaccountedStrides[originalStride] <
999  candidateStridesNumOccurences[originalStride]) {
1000       // This should never happen. Can't have a stride in the reduced rank
1001       // type that wasn't in the original one.
1002  return failure();
1003  }
1004  }
1005 
1006  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1007  originalType.getRank())
1008  return failure();
1009  return unusedDims;
1010 }
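// Example (illustrative, hypothetical values): with
//   originalType = memref<1x1x8xf32, strided<[8, 8, 1]>>,
//   reducedType  = memref<1x8xf32, strided<[8, 1]>>, and sizes = [1, 1, 8],
// both unit dims are candidates, but stride 8 occurs twice in the original and
// only once in the candidate, so dim 0 is reported as dropped and dim 1 is
// kept.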
1011 
1012 llvm::SmallBitVector SubViewOp::getDroppedDims() {
1013  MemRefType sourceType = getSourceType();
1014  MemRefType resultType = getType();
1015  FailureOr<llvm::SmallBitVector> unusedDims =
1016  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1017  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1018  return *unusedDims;
1019 }
1020 
1021 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1022  // All forms of folding require a known index.
1023  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1024  if (!index)
1025  return {};
1026 
1027  // Folding for unranked types (UnrankedMemRefType) is not supported.
1028  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1029  if (!memrefType)
1030  return {};
1031 
1032  // Out of bound indices produce undefined behavior but are still valid IR.
1033  // Don't choke on them.
1034  int64_t indexVal = index.getInt();
1035  if (indexVal < 0 || indexVal >= memrefType.getRank())
1036  return {};
1037 
1038  // Fold if the shape extent along the given index is known.
1039  if (!memrefType.isDynamicDim(index.getInt())) {
1040  Builder builder(getContext());
1041  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1042  }
1043 
1044  // The size at the given index is now known to be a dynamic size.
1045  unsigned unsignedIndex = index.getValue().getZExtValue();
1046 
1047  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1048  Operation *definingOp = getSource().getDefiningOp();
1049 
1050  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1051  return *(alloc.getDynamicSizes().begin() +
1052  memrefType.getDynamicDimIndex(unsignedIndex));
1053 
1054  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1055  return *(alloca.getDynamicSizes().begin() +
1056  memrefType.getDynamicDimIndex(unsignedIndex));
1057 
1058  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1059  return *(view.getDynamicSizes().begin() +
1060  memrefType.getDynamicDimIndex(unsignedIndex));
1061 
1062  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1063  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1064  unsigned resultIndex = 0;
1065  unsigned sourceRank = subview.getSourceType().getRank();
1066  unsigned sourceIndex = 0;
1067  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1068  if (unusedDims.test(i))
1069  continue;
1070  if (resultIndex == unsignedIndex) {
1071  sourceIndex = i;
1072  break;
1073  }
1074  resultIndex++;
1075  }
1076  assert(subview.isDynamicSize(sourceIndex) &&
1077  "expected dynamic subview size");
1078  return subview.getDynamicSize(sourceIndex);
1079  }
1080 
1081  if (auto sizeInterface =
1082  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1083  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1084  "Expected dynamic subview size");
1085  return sizeInterface.getDynamicSize(unsignedIndex);
1086  }
1087 
1088  // dim(memrefcast) -> dim
1089  if (succeeded(foldMemRefCast(*this)))
1090  return getResult();
1091 
1092  return {};
1093 }
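// Example (illustrative, hypothetical values): dim folds to the size operand
// of the defining alloc when the queried dimension is dynamic:
//
//   %0 = memref.alloc(%w) : memref<4x?xf32>
//   %c1 = arith.constant 1 : index
//   %d = memref.dim %0, %c1 : memref<4x?xf32>
//
// folds %d to %w.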
1094 
1095 namespace {
1096 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1097 /// operand.
1098 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1099   using OpRewritePattern<DimOp>::OpRewritePattern;
1100 
1101  LogicalResult matchAndRewrite(DimOp dim,
1102  PatternRewriter &rewriter) const override {
1103  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1104 
1105  if (!reshape)
1106  return rewriter.notifyMatchFailure(
1107  dim, "Dim op is not defined by a reshape op.");
1108 
1109  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1110  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1111  // cheaply check that either of the following conditions hold:
1112  // 1. dim.getIndex() is defined in the same block as reshape but before
1113  // reshape.
1114  // 2. dim.getIndex() is defined in a parent block of
1115  // reshape.
1116 
1117  // Check condition 1
1118  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1119  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1120  if (reshape->isBeforeInBlock(definingOp)) {
1121  return rewriter.notifyMatchFailure(
1122  dim,
1123  "dim.getIndex is not defined before reshape in the same block.");
1124  }
1125  } // else dim.getIndex is a block argument to reshape->getBlock and
1126  // dominates reshape
1127  } // Check condition 2
1128  else if (dim->getBlock() != reshape->getBlock() &&
1129  !dim.getIndex().getParentRegion()->isProperAncestor(
1130  reshape->getParentRegion())) {
1131  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1132  // already know dim.getIndex() dominates reshape without calling
1133  // `isProperAncestor`
1134  return rewriter.notifyMatchFailure(
1135  dim, "dim.getIndex does not dominate reshape.");
1136  }
1137 
1138  // Place the load directly after the reshape to ensure that the shape memref
1139  // was not mutated.
1140  rewriter.setInsertionPointAfter(reshape);
1141  Location loc = dim.getLoc();
1142  Value load =
1143  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1144  if (load.getType() != dim.getType())
1145  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1146  rewriter.replaceOp(dim, load);
1147  return success();
1148  }
1149 };
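// Example (illustrative, hypothetical values), assuming %idx dominates the
// reshape:
//
//   %r = memref.reshape %m(%shape)
//       : (memref<?x?xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %idx : memref<?x?xf32>
//
// is rewritten to:
//
//   %d = memref.load %shape[%idx] : memref<2xindex>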
1150 
1151 } // namespace
1152 
1153 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1154  MLIRContext *context) {
1155  results.add<DimOfMemRefReshape>(context);
1156 }
1157 
1158 // ---------------------------------------------------------------------------
1159 // DmaStartOp
1160 // ---------------------------------------------------------------------------
1161 
1162 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1163  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1164  ValueRange destIndices, Value numElements,
1165  Value tagMemRef, ValueRange tagIndices, Value stride,
1166  Value elementsPerStride) {
1167  result.addOperands(srcMemRef);
1168  result.addOperands(srcIndices);
1169  result.addOperands(destMemRef);
1170  result.addOperands(destIndices);
1171  result.addOperands({numElements, tagMemRef});
1172  result.addOperands(tagIndices);
1173  if (stride)
1174  result.addOperands({stride, elementsPerStride});
1175 }
1176 
1177 void DmaStartOp::print(OpAsmPrinter &p) {
1178   p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1179  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1180  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1181  if (isStrided())
1182  p << ", " << getStride() << ", " << getNumElementsPerStride();
1183 
1184  p.printOptionalAttrDict((*this)->getAttrs());
1185  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1186  << ", " << getTagMemRef().getType();
1187 }
1188 
1189 // Parse DmaStartOp.
1190 // Ex:
1191 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1192 // %tag[%index], %stride, %num_elt_per_stride :
1193 // : memref<3076 x f32, 0>,
1194 // memref<1024 x f32, 2>,
1195 // memref<1 x i32>
1196 //
1197 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1198   OpAsmParser::UnresolvedOperand srcMemRefInfo;
1199   SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1200   OpAsmParser::UnresolvedOperand dstMemRefInfo;
1201   SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1202   OpAsmParser::UnresolvedOperand numElementsInfo;
1203   OpAsmParser::UnresolvedOperand tagMemrefInfo;
1204   SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1205   SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1206 
1207  SmallVector<Type, 3> types;
1208  auto indexType = parser.getBuilder().getIndexType();
1209 
1210  // Parse and resolve the following list of operands:
1211  // *) source memref followed by its indices (in square brackets).
1212  // *) destination memref followed by its indices (in square brackets).
1213  // *) dma size in KiB.
1214  if (parser.parseOperand(srcMemRefInfo) ||
1215  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1216  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1217  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1218  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1219  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1220  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1221  return failure();
1222 
1223  // Parse optional stride and elements per stride.
1224  if (parser.parseTrailingOperandList(strideInfo))
1225  return failure();
1226 
1227  bool isStrided = strideInfo.size() == 2;
1228  if (!strideInfo.empty() && !isStrided) {
1229  return parser.emitError(parser.getNameLoc(),
1230  "expected two stride related operands");
1231  }
1232 
1233  if (parser.parseColonTypeList(types))
1234  return failure();
1235  if (types.size() != 3)
1236  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1237 
1238  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1239  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1240  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1241  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1242  // size should be an index.
1243  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1244  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1245  // tag indices should be index.
1246  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1247  return failure();
1248 
1249  if (isStrided) {
1250  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1251  return failure();
1252  }
1253 
1254  return success();
1255 }
1256 
1257 LogicalResult DmaStartOp::verify() {
1258  unsigned numOperands = getNumOperands();
1259 
1260  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1261  // the number of elements.
1262  if (numOperands < 4)
1263  return emitOpError("expected at least 4 operands");
1264 
1265  // Check types of operands. The order of these calls is important: the later
1266  // calls rely on some type properties to compute the operand position.
1267  // 1. Source memref.
1268  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1269  return emitOpError("expected source to be of memref type");
1270  if (numOperands < getSrcMemRefRank() + 4)
1271  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1272  << " operands";
1273  if (!getSrcIndices().empty() &&
1274  !llvm::all_of(getSrcIndices().getTypes(),
1275  [](Type t) { return t.isIndex(); }))
1276  return emitOpError("expected source indices to be of index type");
1277 
1278  // 2. Destination memref.
1279  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1280  return emitOpError("expected destination to be of memref type");
1281  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1282  if (numOperands < numExpectedOperands)
1283  return emitOpError() << "expected at least " << numExpectedOperands
1284  << " operands";
1285  if (!getDstIndices().empty() &&
1286  !llvm::all_of(getDstIndices().getTypes(),
1287  [](Type t) { return t.isIndex(); }))
1288  return emitOpError("expected destination indices to be of index type");
1289 
1290  // 3. Number of elements.
1291  if (!getNumElements().getType().isIndex())
1292  return emitOpError("expected num elements to be of index type");
1293 
1294  // 4. Tag memref.
1295  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1296  return emitOpError("expected tag to be of memref type");
1297  numExpectedOperands += getTagMemRefRank();
1298  if (numOperands < numExpectedOperands)
1299  return emitOpError() << "expected at least " << numExpectedOperands
1300  << " operands";
1301  if (!getTagIndices().empty() &&
1302  !llvm::all_of(getTagIndices().getTypes(),
1303  [](Type t) { return t.isIndex(); }))
1304  return emitOpError("expected tag indices to be of index type");
1305 
1306  // Optional stride-related operands must be either both present or both
1307  // absent.
1308  if (numOperands != numExpectedOperands &&
1309  numOperands != numExpectedOperands + 2)
1310  return emitOpError("incorrect number of operands");
1311 
1312  // 5. Strides.
1313  if (isStrided()) {
1314  if (!getStride().getType().isIndex() ||
1315  !getNumElementsPerStride().getType().isIndex())
1316  return emitOpError(
1317  "expected stride and num elements per stride to be of type index");
1318  }
1319 
1320  return success();
1321 }
1322 
1323 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1324  SmallVectorImpl<OpFoldResult> &results) {
1325  /// dma_start(memrefcast) -> dma_start
1326  return foldMemRefCast(*this);
1327 }
1328 
1329 // ---------------------------------------------------------------------------
1330 // DmaWaitOp
1331 // ---------------------------------------------------------------------------
1332 
1333 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1334  SmallVectorImpl<OpFoldResult> &results) {
1335  /// dma_wait(memrefcast) -> dma_wait
1336  return foldMemRefCast(*this);
1337 }
1338 
1339 LogicalResult DmaWaitOp::verify() {
1340  // Check that the number of tag indices matches the tagMemRef rank.
1341  unsigned numTagIndices = getTagIndices().size();
1342  unsigned tagMemRefRank = getTagMemRefRank();
1343  if (numTagIndices != tagMemRefRank)
1344  return emitOpError() << "expected tagIndices to have the same number of "
1345  "elements as the tagMemRef rank, expected "
1346  << tagMemRefRank << ", but got " << numTagIndices;
1347  return success();
1348 }
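// Example (illustrative, hypothetical values): the number of tag indices must
// match the tag memref rank:
//
//   memref.dma_wait %tag[%c0], %num_elements : memref<1xi32>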
1349 
1350 //===----------------------------------------------------------------------===//
1351 // ExtractAlignedPointerAsIndexOp
1352 //===----------------------------------------------------------------------===//
1353 
1354 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1355  function_ref<void(Value, StringRef)> setNameFn) {
1356  setNameFn(getResult(), "intptr");
1357 }
1358 
1359 //===----------------------------------------------------------------------===//
1360 // ExtractStridedMetadataOp
1361 //===----------------------------------------------------------------------===//
1362 
1363 /// The number and type of the results are inferred from the
1364 /// shape of the source.
1365 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1366  MLIRContext *context, std::optional<Location> location,
1367  ExtractStridedMetadataOp::Adaptor adaptor,
1368  SmallVectorImpl<Type> &inferredReturnTypes) {
1369  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1370  if (!sourceType)
1371  return failure();
1372 
1373  unsigned sourceRank = sourceType.getRank();
1374  IndexType indexType = IndexType::get(context);
1375  auto memrefType =
1376  MemRefType::get({}, sourceType.getElementType(),
1377  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1378  // Base.
1379  inferredReturnTypes.push_back(memrefType);
1380  // Offset.
1381  inferredReturnTypes.push_back(indexType);
1382  // Sizes and strides.
1383  for (unsigned i = 0; i < sourceRank * 2; ++i)
1384  inferredReturnTypes.push_back(indexType);
1385  return success();
1386 }
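// Example (illustrative, hypothetical values): for a 2-D source the op infers
// a 0-D base buffer plus five index results (offset, two sizes, two strides):
//
//   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
//       : memref<?x?xf32> -> memref<f32>, index, index, index, index, index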
1387 
1388 void ExtractStridedMetadataOp::getAsmResultNames(
1389  function_ref<void(Value, StringRef)> setNameFn) {
1390  setNameFn(getBaseBuffer(), "base_buffer");
1391  setNameFn(getOffset(), "offset");
1392  // For multi-result to work properly with pretty names and packed syntax `x:3`
1393  // we can only give a pretty name to the first value in the pack.
1394  if (!getSizes().empty()) {
1395  setNameFn(getSizes().front(), "sizes");
1396  setNameFn(getStrides().front(), "strides");
1397  }
1398 }
1399 
1400 /// Helper function to perform the replacement of all constant uses of `values`
1401 /// by a materialized constant extracted from `maybeConstants`.
1402 /// `values` and `maybeConstants` are expected to have the same size.
1403 template <typename Container>
1404 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1405  Container values,
1406  ArrayRef<OpFoldResult> maybeConstants) {
1407  assert(values.size() == maybeConstants.size() &&
1408  " expected values and maybeConstants of the same size");
1409  bool atLeastOneReplacement = false;
1410  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1411     // Don't materialize a constant if there are no uses: this would induce
1412     // infinite loops in the driver.
1413  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1414  continue;
1415  assert(isa<Attribute>(maybeConstant) &&
1416  "The constified value should be either unchanged (i.e., == result) "
1417  "or a constant");
1418  Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1419  loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1420  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1421  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1422  // yet.
1423  op->replaceUsesOfWith(result, constantVal);
1424  atLeastOneReplacement = true;
1425  }
1426  }
1427  return atLeastOneReplacement;
1428 }
1429 
1430 LogicalResult
1431 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1432  SmallVectorImpl<OpFoldResult> &results) {
1433  OpBuilder builder(*this);
1434 
1435  bool atLeastOneReplacement = replaceConstantUsesOf(
1436  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1437  getConstifiedMixedOffset());
1438  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1439  getConstifiedMixedSizes());
1440  atLeastOneReplacement |= replaceConstantUsesOf(
1441  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1442 
1443  return success(atLeastOneReplacement);
1444 }
1445 
1446 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1447  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1448  constifyIndexValues(values, getSource().getType(), getContext(),
1449  getConstantSizes, ShapedType::isDynamic);
1450  return values;
1451 }
1452 
1453 SmallVector<OpFoldResult>
1454 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1455  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1456  constifyIndexValues(values, getSource().getType(), getContext(),
1457  getConstantStrides, ShapedType::isDynamic);
1458  return values;
1459 }
1460 
1461 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1462  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1463  SmallVector<OpFoldResult> values(1, offsetOfr);
1464  constifyIndexValues(values, getSource().getType(), getContext(),
1465  getConstantOffset, ShapedType::isDynamic);
1466  return values[0];
1467 }
1468 
1469 //===----------------------------------------------------------------------===//
1470 // GenericAtomicRMWOp
1471 //===----------------------------------------------------------------------===//
1472 
1473 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1474  Value memref, ValueRange ivs) {
1475  OpBuilder::InsertionGuard g(builder);
1476  result.addOperands(memref);
1477  result.addOperands(ivs);
1478 
1479  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1480  Type elementType = memrefType.getElementType();
1481  result.addTypes(elementType);
1482 
1483  Region *bodyRegion = result.addRegion();
1484  builder.createBlock(bodyRegion);
1485  bodyRegion->addArgument(elementType, memref.getLoc());
1486  }
1487 }
1488 
1489 LogicalResult GenericAtomicRMWOp::verify() {
1490  auto &body = getRegion();
1491  if (body.getNumArguments() != 1)
1492  return emitOpError("expected a single entry block argument");
1493 
1494  if (getResult().getType() != body.getArgument(0).getType())
1495  return emitOpError("expected block argument type to match result type");
1496 
1497  bool hasSideEffects =
1498  body.walk([&](Operation *nestedOp) {
1499  if (isMemoryEffectFree(nestedOp))
1500  return WalkResult::advance();
1501  nestedOp->emitError(
1502  "body of 'memref.generic_atomic_rmw' should contain "
1503  "only operations with no side effects");
1504  return WalkResult::interrupt();
1505  })
1506  .wasInterrupted();
1507  return hasSideEffects ? failure() : success();
1508 }
1509 
1510 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1511  OperationState &result) {
1512  OpAsmParser::UnresolvedOperand memref;
1513  Type memrefType;
1514  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1515 
1516  Type indexType = parser.getBuilder().getIndexType();
1517  if (parser.parseOperand(memref) ||
1518  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1519  parser.parseColonType(memrefType) ||
1520  parser.resolveOperand(memref, memrefType, result.operands) ||
1521  parser.resolveOperands(ivs, indexType, result.operands))
1522  return failure();
1523 
1524  Region *body = result.addRegion();
1525  if (parser.parseRegion(*body, {}) ||
1526  parser.parseOptionalAttrDict(result.attributes))
1527  return failure();
1528  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1529  return success();
1530 }
1531 
1532 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1533  p << ' ' << getMemref() << "[" << getIndices()
1534  << "] : " << getMemref().getType() << ' ';
1535  p.printRegion(getRegion());
1536  p.printOptionalAttrDict((*this)->getAttrs());
1537 }
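// For reference, the parser and printer above round-trip assembly of roughly
// this shape (a minimal sketch; names are illustrative):
//
//   %res = memref.generic_atomic_rmw %buf[%i] : memref<10xf32> {
//   ^bb0(%current : f32):
//     %cst = arith.constant 1.0 : f32
//     %new = arith.addf %current, %cst : f32
//     memref.atomic_yield %new : f32
//   }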
1538 
1539 //===----------------------------------------------------------------------===//
1540 // AtomicYieldOp
1541 //===----------------------------------------------------------------------===//
1542 
1543 LogicalResult AtomicYieldOp::verify() {
1544  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1545  Type resultType = getResult().getType();
1546  if (parentType != resultType)
1547  return emitOpError() << "types mismatch between yield op: " << resultType
1548  << " and its parent: " << parentType;
1549  return success();
1550 }
1551 
1552 //===----------------------------------------------------------------------===//
1553 // GlobalOp
1554 //===----------------------------------------------------------------------===//
1555 
1556 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1557  TypeAttr type,
1558  Attribute initialValue) {
1559  p << type;
1560  if (!op.isExternal()) {
1561  p << " = ";
1562  if (op.isUninitialized())
1563  p << "uninitialized";
1564  else
1565  p.printAttributeWithoutType(initialValue);
1566  }
1567 }
1568 
1569 static ParseResult
1570 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1571  Attribute &initialValue) {
1572  Type type;
1573  if (parser.parseType(type))
1574  return failure();
1575 
1576  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1577  if (!memrefType || !memrefType.hasStaticShape())
1578  return parser.emitError(parser.getNameLoc())
1579  << "type should be static shaped memref, but got " << type;
1580  typeAttr = TypeAttr::get(type);
1581 
1582  if (parser.parseOptionalEqual())
1583  return success();
1584 
1585  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1586  initialValue = UnitAttr::get(parser.getContext());
1587  return success();
1588  }
1589 
1590  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1591  if (parser.parseAttribute(initialValue, tensorType))
1592  return failure();
1593  if (!llvm::isa<ElementsAttr>(initialValue))
1594  return parser.emitError(parser.getNameLoc())
1595  << "initial value should be a unit or elements attribute";
1596  return success();
1597 }
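// For reference, the type-and-initial-value syntax handled by the printer and
// parser above covers forms such as (a minimal sketch; names are illustrative):
//
//   memref.global "private" constant @c : memref<2xi32> = dense<[1, 4]>
//   memref.global @z : memref<3xf32> = uninitialized
//   memref.global @ext : memref<4xi32>   // external declaration, no value
//
// with later accesses written as:
//
//   %0 = memref.get_global @c : memref<2xi32>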
1598 
1599 LogicalResult GlobalOp::verify() {
1600  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1601  if (!memrefType || !memrefType.hasStaticShape())
1602  return emitOpError("type should be static shaped memref, but got ")
1603  << getType();
1604 
1605  // Verify that the initial value, if present, is either a unit attribute or
1606  // an elements attribute.
1607  if (getInitialValue().has_value()) {
1608  Attribute initValue = getInitialValue().value();
1609  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1610  return emitOpError("initial value should be a unit or elements "
1611  "attribute, but got ")
1612  << initValue;
1613 
1614  // Check that the type of the initial value is compatible with the type of
1615  // the global variable.
1616  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1617  Type initType = elementsAttr.getType();
1618  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1619  if (initType != tensorType)
1620  return emitOpError("initial value expected to be of type ")
1621  << tensorType << ", but was of type " << initType;
1622  }
1623  }
1624 
1625  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1626  uint64_t alignment = *alignAttr;
1627 
1628  if (!llvm::isPowerOf2_64(alignment))
1629  return emitError() << "alignment attribute value " << alignment
1630  << " is not a power of 2";
1631  }
1632 
1633  // TODO: verify visibility for declarations.
1634  return success();
1635 }
1636 
1637 ElementsAttr GlobalOp::getConstantInitValue() {
1638  auto initVal = getInitialValue();
1639  if (getConstant() && initVal.has_value())
1640  return llvm::cast<ElementsAttr>(initVal.value());
1641  return {};
1642 }
1643 
1644 //===----------------------------------------------------------------------===//
1645 // GetGlobalOp
1646 //===----------------------------------------------------------------------===//
1647 
1648 LogicalResult
1649 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1650  // Verify that the result type is same as the type of the referenced
1651  // memref.global op.
1652  auto global =
1653  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1654  if (!global)
1655  return emitOpError("'")
1656  << getName() << "' does not reference a valid global memref";
1657 
1658  Type resultType = getResult().getType();
1659  if (global.getType() != resultType)
1660  return emitOpError("result type ")
1661  << resultType << " does not match type " << global.getType()
1662  << " of the global memref @" << getName();
1663  return success();
1664 }
1665 
1666 //===----------------------------------------------------------------------===//
1667 // LoadOp
1668 //===----------------------------------------------------------------------===//
1669 
1670 LogicalResult LoadOp::verify() {
1671  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1672  return emitOpError("incorrect number of indices for load, expected ")
1673  << getMemRefType().getRank() << " but got " << getIndices().size();
1674  }
1675  return success();
1676 }
1677 
1678 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1679  /// load(memrefcast) -> load
1680  if (succeeded(foldMemRefCast(*this)))
1681  return getResult();
1682  return OpFoldResult();
1683 }
1684 
1685 //===----------------------------------------------------------------------===//
1686 // MemorySpaceCastOp
1687 //===----------------------------------------------------------------------===//
1688 
1689 void MemorySpaceCastOp::getAsmResultNames(
1690  function_ref<void(Value, StringRef)> setNameFn) {
1691  setNameFn(getResult(), "memspacecast");
1692 }
1693 
1694 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1695  if (inputs.size() != 1 || outputs.size() != 1)
1696  return false;
1697  Type a = inputs.front(), b = outputs.front();
1698  auto aT = llvm::dyn_cast<MemRefType>(a);
1699  auto bT = llvm::dyn_cast<MemRefType>(b);
1700 
1701  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1702  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1703 
1704  if (aT && bT) {
1705  if (aT.getElementType() != bT.getElementType())
1706  return false;
1707  if (aT.getLayout() != bT.getLayout())
1708  return false;
1709  if (aT.getShape() != bT.getShape())
1710  return false;
1711  return true;
1712  }
1713  if (uaT && ubT) {
1714  return uaT.getElementType() == ubT.getElementType();
1715  }
1716  return false;
1717 }
1718 
1719 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1720  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1721  // t2)
1722  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1723  getSourceMutable().assign(parentCast.getSource());
1724  return getResult();
1725  }
1726  return Value{};
1727 }
1728 
1729 //===----------------------------------------------------------------------===//
1730 // PrefetchOp
1731 //===----------------------------------------------------------------------===//
1732 
1733 void PrefetchOp::print(OpAsmPrinter &p) {
1734  p << " " << getMemref() << '[';
1735  p.printOperands(getIndices());
1736  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1737  p << ", locality<" << getLocalityHint();
1738  p << ">, " << (getIsDataCache() ? "data" : "instr");
1739  p.printOptionalAttrDict(
1740  (*this)->getAttrs(),
1741  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1742  p << " : " << getMemRefType();
1743 }
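// For reference, the printed form produced above looks like (a minimal sketch;
// names are illustrative):
//
//   memref.prefetch %buf[%i, %j], read, locality<3>, data : memref<400x400xi32>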
1744 
1745 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1746  OpAsmParser::UnresolvedOperand memrefInfo;
1747  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1748  IntegerAttr localityHint;
1749  MemRefType type;
1750  StringRef readOrWrite, cacheType;
1751 
1752  auto indexTy = parser.getBuilder().getIndexType();
1753  auto i32Type = parser.getBuilder().getIntegerType(32);
1754  if (parser.parseOperand(memrefInfo) ||
1755  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1756  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1757  parser.parseComma() || parser.parseKeyword("locality") ||
1758  parser.parseLess() ||
1759  parser.parseAttribute(localityHint, i32Type, "localityHint",
1760  result.attributes) ||
1761  parser.parseGreater() || parser.parseComma() ||
1762  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1763  parser.resolveOperand(memrefInfo, type, result.operands) ||
1764  parser.resolveOperands(indexInfo, indexTy, result.operands))
1765  return failure();
1766 
1767  if (readOrWrite != "read" && readOrWrite != "write")
1768  return parser.emitError(parser.getNameLoc(),
1769  "rw specifier has to be 'read' or 'write'");
1770  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1771  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1772 
1773  if (cacheType != "data" && cacheType != "instr")
1774  return parser.emitError(parser.getNameLoc(),
1775  "cache type has to be 'data' or 'instr'");
1776 
1777  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1778  parser.getBuilder().getBoolAttr(cacheType == "data"));
1779 
1780  return success();
1781 }
1782 
1783 LogicalResult PrefetchOp::verify() {
1784  if (getNumOperands() != 1 + getMemRefType().getRank())
1785  return emitOpError("too few indices");
1786 
1787  return success();
1788 }
1789 
1790 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1791  SmallVectorImpl<OpFoldResult> &results) {
1792  // prefetch(memrefcast) -> prefetch
1793  return foldMemRefCast(*this);
1794 }
1795 
1796 //===----------------------------------------------------------------------===//
1797 // RankOp
1798 //===----------------------------------------------------------------------===//
1799 
1800 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1801  // Constant fold rank when the rank of the operand is known.
1802  auto type = getOperand().getType();
1803  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1804  if (shapedType && shapedType.hasRank())
1805  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1806  return IntegerAttr();
1807 }
1808 
1809 //===----------------------------------------------------------------------===//
1810 // ReinterpretCastOp
1811 //===----------------------------------------------------------------------===//
1812 
1813 void ReinterpretCastOp::getAsmResultNames(
1814  function_ref<void(Value, StringRef)> setNameFn) {
1815  setNameFn(getResult(), "reinterpret_cast");
1816 }
1817 
1818 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1819 /// `staticSizes` and `staticStrides` are automatically filled with
1820 /// source-memref-rank sentinel values that encode dynamic entries.
1821 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1822  MemRefType resultType, Value source,
1823  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1824  ArrayRef<OpFoldResult> strides,
1825  ArrayRef<NamedAttribute> attrs) {
1826  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1827  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1828  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1829  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1830  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1831  result.addAttributes(attrs);
1832  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1833  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1834  b.getDenseI64ArrayAttr(staticSizes),
1835  b.getDenseI64ArrayAttr(staticStrides));
1836 }
1837 
1838 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1839  Value source, OpFoldResult offset,
1840  ArrayRef<OpFoldResult> sizes,
1841  ArrayRef<OpFoldResult> strides,
1842  ArrayRef<NamedAttribute> attrs) {
1843  auto sourceType = cast<BaseMemRefType>(source.getType());
1844  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1845  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1846  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1847  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1848  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1849  auto stridedLayout = StridedLayoutAttr::get(
1850  b.getContext(), staticOffsets.front(), staticStrides);
1851  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1852  stridedLayout, sourceType.getMemorySpace());
1853  build(b, result, resultType, source, offset, sizes, strides, attrs);
1854 }
1855 
1856 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1857  MemRefType resultType, Value source,
1858  int64_t offset, ArrayRef<int64_t> sizes,
1859  ArrayRef<int64_t> strides,
1860  ArrayRef<NamedAttribute> attrs) {
1861  SmallVector<OpFoldResult> sizeValues =
1862  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1863  return b.getI64IntegerAttr(v);
1864  }));
1865  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1866  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1867  return b.getI64IntegerAttr(v);
1868  }));
1869  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1870  strideValues, attrs);
1871 }
1872 
1873 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1874  MemRefType resultType, Value source, Value offset,
1875  ValueRange sizes, ValueRange strides,
1876  ArrayRef<NamedAttribute> attrs) {
1877  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1878  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1879  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1880  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1881  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1882 }
1883 
1884 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1885 // completed automatically, like we have for subview and extract_slice.
1886 LogicalResult ReinterpretCastOp::verify() {
1887  // The source and result memrefs should be in the same memory space.
1888  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1889  auto resultType = llvm::cast<MemRefType>(getType());
1890  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1891  return emitError("different memory spaces specified for source type ")
1892  << srcType << " and result memref type " << resultType;
1893  if (srcType.getElementType() != resultType.getElementType())
1894  return emitError("different element types specified for source type ")
1895  << srcType << " and result memref type " << resultType;
1896 
1897  // Match sizes in result memref type and in static_sizes attribute.
1898  for (auto [idx, resultSize, expectedSize] :
1899  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1900  if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
1901  return emitError("expected result type with size = ")
1902  << (ShapedType::isDynamic(expectedSize)
1903  ? std::string("dynamic")
1904  : std::to_string(expectedSize))
1905  << " instead of " << resultSize << " in dim = " << idx;
1906  }
1907 
1908  // Match offset and strides in static_offset and static_strides attributes. If
1909  // result memref type has no affine map specified, this will assume an
1910  // identity layout.
1911  int64_t resultOffset;
1912  SmallVector<int64_t, 4> resultStrides;
1913  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1914  return emitError("expected result type to have strided layout but found ")
1915  << resultType;
1916 
1917  // Match offset in result memref type and in static_offsets attribute.
1918  int64_t expectedOffset = getStaticOffsets().front();
1919  if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
1920  return emitError("expected result type with offset = ")
1921  << (ShapedType::isDynamic(expectedOffset)
1922  ? std::string("dynamic")
1923  : std::to_string(expectedOffset))
1924  << " instead of " << resultOffset;
1925 
1926  // Match strides in result memref type and in static_strides attribute.
1927  for (auto [idx, resultStride, expectedStride] :
1928  llvm::enumerate(resultStrides, getStaticStrides())) {
1929  if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
1930  return emitError("expected result type with stride = ")
1931  << (ShapedType::isDynamic(expectedStride)
1932  ? std::string("dynamic")
1933  : std::to_string(expectedStride))
1934  << " instead of " << resultStride << " in dim = " << idx;
1935  }
1936 
1937  return success();
1938 }
1939 
1940 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1941  Value src = getSource();
1942  auto getPrevSrc = [&]() -> Value {
1943  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1944  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1945  return prev.getSource();
1946 
1947  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1948  if (auto prev = src.getDefiningOp<CastOp>())
1949  return prev.getSource();
1950 
1951  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1952  // are 0.
1953  if (auto prev = src.getDefiningOp<SubViewOp>())
1954  if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1955  return isConstantIntValue(val, 0);
1956  }))
1957  return prev.getSource();
1958 
1959  return nullptr;
1960  };
1961 
1962  if (auto prevSrc = getPrevSrc()) {
1963  getSourceMutable().assign(prevSrc);
1964  return getResult();
1965  }
1966 
1967  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1968  if (!ShapedType::isDynamicShape(getType().getShape()) &&
1969  src.getType() == getType() && getStaticOffsets().front() == 0) {
1970  return src;
1971  }
1972 
1973  return nullptr;
1974 }
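// For illustration, the source-forwarding fold above turns (a minimal sketch;
// names are illustrative):
//
//   %0 = memref.reinterpret_cast %src to offset: [0], sizes: [4, 4],
//        strides: [4, 1] : memref<16xf32> to memref<4x4xf32>
//   %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [16],
//        strides: [1] : memref<4x4xf32> to memref<16xf32>
//
// into a reinterpret_cast of %src directly; the identity case then folds %1
// away entirely because its type matches memref<16xf32> with offset 0.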
1975 
1976 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1977  SmallVector<OpFoldResult> values = getMixedSizes();
1978  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
1979  ShapedType::isDynamic);
1980  return values;
1981 }
1982 
1983 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1984  SmallVector<OpFoldResult> values = getMixedStrides();
1985  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
1986  ShapedType::isDynamic);
1987  return values;
1988 }
1989 
1990 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1991  SmallVector<OpFoldResult> values = getMixedOffsets();
1992  assert(values.size() == 1 &&
1993  "reinterpret_cast must have one and only one offset");
1994  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
1995  ShapedType::isDynamic);
1996  return values[0];
1997 }
1998 
1999 namespace {
2000 /// Replace the sequence:
2001 /// ```
2002 /// base, offset, sizes, strides = extract_strided_metadata src
2003 /// dst = reinterpret_cast base to offset, sizes, strides
2004 /// ```
2005 /// With
2006 ///
2007 /// ```
2008 /// dst = memref.cast src
2009 /// ```
2010 ///
2011 /// Note: The cast operation is only inserted when the type of dst and src
2012 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
2013 ///
2014 /// This pattern also matches when the offset, sizes, and strides don't come
2015 /// directly from the `extract_strided_metadata`'s results but it can be
2016 /// statically proven that they would hold the same values.
2017 ///
2018 /// For instance, the following sequence would be replaced:
2019 /// ```
2020 /// base, offset, sizes, strides =
2021 /// extract_strided_metadata memref : memref<3x4xty>
2022 /// dst = reinterpret_cast base to 0, [3, 4], strides
2023 /// ```
2024 /// Because we know (thanks to the type of the input memref) that the
2025 /// variables `offset` and `sizes` will hold 0 and [3, 4], respectively.
2026 ///
2027 /// Similarly, the following sequence would be replaced:
2028 /// ```
2029 /// c0 = arith.constant 0
2030 /// c4 = arith.constant 4
2031 /// base, offset, sizes, strides =
2032 /// extract_strided_metadata memref : memref<3x4xty>
2033 /// dst = reinterpret_cast base to c0, [3, c4], strides
2034 /// ```
2035 /// Because we know that `offset` and `c0` will hold 0
2036 /// and `c4` will hold 4.
2037 struct ReinterpretCastOpExtractStridedMetadataFolder
2038  : public OpRewritePattern<ReinterpretCastOp> {
2039 public:
2040  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2041 
2042  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2043  PatternRewriter &rewriter) const override {
2044  auto extractStridedMetadata =
2045  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2046  if (!extractStridedMetadata)
2047  return failure();
2048  // Check if the reinterpret cast reconstructs a memref with the exact same
2049  // properties as the extract strided metadata.
2050 
2051  // First, check that the strides are the same.
2052  SmallVector<OpFoldResult> extractStridesOfr =
2053  extractStridedMetadata.getConstifiedMixedStrides();
2054  SmallVector<OpFoldResult> reinterpretStridesOfr =
2055  op.getConstifiedMixedStrides();
2056  if (extractStridesOfr.size() != reinterpretStridesOfr.size())
2057  return failure();
2058 
2059  unsigned rank = op.getType().getRank();
2060  for (unsigned i = 0; i < rank; ++i) {
2061  if (extractStridesOfr[i] != reinterpretStridesOfr[i])
2062  return failure();
2063  }
2064 
2065  // Second, check the sizes.
2066  assert(extractStridedMetadata.getSizes().size() ==
2067  op.getMixedSizes().size() &&
2068  "Strides and sizes rank must match");
2069  SmallVector<OpFoldResult> extractSizesOfr =
2070  extractStridedMetadata.getConstifiedMixedSizes();
2071  SmallVector<OpFoldResult> reinterpretSizesOfr =
2072  op.getConstifiedMixedSizes();
2073  for (unsigned i = 0; i < rank; ++i) {
2074  if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2075  return failure();
2076  }
2077  // Finally, check the offset.
2078  assert(op.getMixedOffsets().size() == 1 &&
2079  "reinterpret_cast with more than one offset should have been "
2080  "rejected by the verifier");
2081  OpFoldResult extractOffsetOfr =
2082  extractStridedMetadata.getConstifiedMixedOffset();
2083  OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2084  if (extractOffsetOfr != reinterpretOffsetOfr)
2085  return failure();
2086 
2087  // At this point, we know that the back and forth between extract strided
2088  // metadata and reinterpret cast is a noop. However, the final type of the
2089  // reinterpret cast may not be exactly the same as the original memref.
2090  // E.g., it could be changing a dimension from static to dynamic. Check that
2091  // here and add a cast if necessary.
2092  Type srcTy = extractStridedMetadata.getSource().getType();
2093  if (srcTy == op.getResult().getType())
2094  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2095  else
2096  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2097  extractStridedMetadata.getSource());
2098 
2099  return success();
2100  }
2101 };
2102 } // namespace
2103 
2104 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2105  MLIRContext *context) {
2106  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2107 }
2108 
2109 //===----------------------------------------------------------------------===//
2110 // Reassociative reshape ops
2111 //===----------------------------------------------------------------------===//
2112 
2113 void CollapseShapeOp::getAsmResultNames(
2114  function_ref<void(Value, StringRef)> setNameFn) {
2115  setNameFn(getResult(), "collapse_shape");
2116 }
2117 
2118 void ExpandShapeOp::getAsmResultNames(
2119  function_ref<void(Value, StringRef)> setNameFn) {
2120  setNameFn(getResult(), "expand_shape");
2121 }
2122 
2123 LogicalResult ExpandShapeOp::reifyResultShapes(
2124  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2125  reifiedResultShapes = {
2126  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2127  return success();
2128 }
2129 
2130 /// Helper function for verifying the shape of ExpandShapeOp and
2131 /// CollapseShapeOp result and operand. Layout maps are verified separately.
2132 ///
2133 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2134 /// allowed in a reassociation group.
2135 static LogicalResult
2136 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2137  ArrayRef<int64_t> expandedShape,
2138  ArrayRef<ReassociationIndices> reassociation,
2139  bool allowMultipleDynamicDimsPerGroup) {
2140  // There must be one reassociation group per collapsed dimension.
2141  if (collapsedShape.size() != reassociation.size())
2142  return op->emitOpError("invalid number of reassociation groups: found ")
2143  << reassociation.size() << ", expected " << collapsedShape.size();
2144 
2145  // The next expected expanded dimension index (while iterating over
2146  // reassociation indices).
2147  int64_t nextDim = 0;
2148  for (const auto &it : llvm::enumerate(reassociation)) {
2149  ReassociationIndices group = it.value();
2150  int64_t collapsedDim = it.index();
2151 
2152  bool foundDynamic = false;
2153  for (int64_t expandedDim : group) {
2154  if (expandedDim != nextDim++)
2155  return op->emitOpError("reassociation indices must be contiguous");
2156 
2157  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2158  return op->emitOpError("reassociation index ")
2159  << expandedDim << " is out of bounds";
2160 
2161  // Check if there are multiple dynamic dims in a reassociation group.
2162  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2163  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2164  return op->emitOpError(
2165  "at most one dimension in a reassociation group may be dynamic");
2166  foundDynamic = true;
2167  }
2168  }
2169 
2170  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2171  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2172  return op->emitOpError("collapsed dim (")
2173  << collapsedDim
2174  << ") must be dynamic if and only if reassociation group is "
2175  "dynamic";
2176 
2177  // If all dims in the reassociation group are static, the size of the
2178  // collapsed dim can be verified.
2179  if (!foundDynamic) {
2180  int64_t groupSize = 1;
2181  for (int64_t expandedDim : group)
2182  groupSize *= expandedShape[expandedDim];
2183  if (groupSize != collapsedShape[collapsedDim])
2184  return op->emitOpError("collapsed dim size (")
2185  << collapsedShape[collapsedDim]
2186  << ") must equal reassociation group size (" << groupSize << ")";
2187  }
2188  }
2189 
2190  if (collapsedShape.empty()) {
2191  // Rank 0: All expanded dimensions must be 1.
2192  for (int64_t d : expandedShape)
2193  if (d != 1)
2194  return op->emitOpError(
2195  "rank 0 memrefs can only be extended/collapsed with/from ones");
2196  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2197  // Rank >= 1: Number of dimensions among all reassociation groups must match
2198  // the result memref rank.
2199  return op->emitOpError("expanded rank (")
2200  << expandedShape.size()
2201  << ") inconsistent with number of reassociation indices (" << nextDim
2202  << ")";
2203  }
2204 
2205  return success();
2206 }
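// For illustration, with reassociation [[0, 1], [2]] the check above accepts
//   expandedShape = [2, 5, ?]  with  collapsedShape = [10, ?]
// but rejects collapsedShape = [?, ?], because the first group is fully static
// and must therefore collapse to the static product 10 (shapes are made up for
// the example).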
2207 
2208 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2209  return getSymbolLessAffineMaps(getReassociationExprs());
2210 }
2211 
2212 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2213  return convertReassociationIndicesToExprs(getContext(),
2214  getReassociationIndices());
2215 }
2216 
2217 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2218  return getSymbolLessAffineMaps(getReassociationExprs());
2219 }
2220 
2221 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2222  return convertReassociationIndicesToExprs(getContext(),
2223  getReassociationIndices());
2224 }
2225 
2226 /// Compute the layout map after expanding a given source MemRef type with the
2227 /// specified reassociation indices.
2228 static FailureOr<StridedLayoutAttr>
2229 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2230  ArrayRef<ReassociationIndices> reassociation) {
2231  int64_t srcOffset;
2232  SmallVector<int64_t> srcStrides;
2233  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2234  return failure();
2235  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2236 
2237  // 1-1 mapping between srcStrides and reassociation packs.
2238  // Each srcStride starts with the given value and gets expanded according to
2239  // the proper entries in resultShape.
2240  // Example:
2241  // srcStrides = [10000, 1 , 100 ],
2242  // reassociations = [ [0], [1], [2, 3, 4]],
2243  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2244  // -> For the purpose of stride calculation, the useful sizes are:
2245  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2246  // resultStrides = [10000, 1, 600, 200, 100]
2247  // Note that a stride does not get expanded along the first entry of each
2248  // shape pack.
2249  SmallVector<int64_t> reverseResultStrides;
2250  reverseResultStrides.reserve(resultShape.size());
2251  unsigned shapeIndex = resultShape.size() - 1;
2252  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2253  ReassociationIndices reassoc = std::get<0>(it);
2254  int64_t currentStrideToExpand = std::get<1>(it);
2255  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2256  reverseResultStrides.push_back(currentStrideToExpand);
2257  currentStrideToExpand =
2258  (SaturatedInteger::wrap(currentStrideToExpand) *
2259  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2260  .asInteger();
2261  }
2262  }
2263  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2264  resultStrides.resize(resultShape.size(), 1);
2265  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2266 }
2267 
2268 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2269  MemRefType srcType, ArrayRef<int64_t> resultShape,
2270  ArrayRef<ReassociationIndices> reassociation) {
2271  if (srcType.getLayout().isIdentity()) {
2272  // If the source is contiguous (i.e., no layout map specified), so is the
2273  // result.
2274  MemRefLayoutAttrInterface layout;
2275  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2276  srcType.getMemorySpace());
2277  }
2278 
2279  // Source may not be contiguous. Compute the layout map.
2280  FailureOr<StridedLayoutAttr> computedLayout =
2281  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2282  if (failed(computedLayout))
2283  return failure();
2284  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2285  srcType.getMemorySpace());
2286 }
2287 
2288 FailureOr<SmallVector<OpFoldResult>>
2289 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2290  MemRefType expandedType,
2291  ArrayRef<ReassociationIndices> reassociation,
2292  ArrayRef<OpFoldResult> inputShape) {
2293  std::optional<SmallVector<OpFoldResult>> outputShape =
2294  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2295  inputShape);
2296  if (!outputShape)
2297  return failure();
2298  return *outputShape;
2299 }
2300 
2301 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2302  Type resultType, Value src,
2303  ArrayRef<ReassociationIndices> reassociation,
2304  ArrayRef<OpFoldResult> outputShape) {
2305  auto [staticOutputShape, dynamicOutputShape] =
2306  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2307  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2308  getReassociationIndicesAttribute(builder, reassociation),
2309  dynamicOutputShape, staticOutputShape);
2310 }
2311 
2312 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2313  Type resultType, Value src,
2314  ArrayRef<ReassociationIndices> reassociation) {
2315  SmallVector<OpFoldResult> inputShape =
2316  getMixedSizes(builder, result.location, src);
2317  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2318  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2319  builder, result.location, memrefResultTy, reassociation, inputShape);
2320  // Failure of this assertion usually indicates presence of multiple
2321  // dynamic dimensions in the same reassociation group.
2322  assert(succeeded(outputShape) && "unable to infer output shape");
2323  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2324 }
2325 
2326 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2327  ArrayRef<int64_t> resultShape, Value src,
2328  ArrayRef<ReassociationIndices> reassociation) {
2329  // Only ranked memref source values are supported.
2330  auto srcType = llvm::cast<MemRefType>(src.getType());
2331  FailureOr<MemRefType> resultType =
2332  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2333  // Failure of this assertion usually indicates a problem with the source
2334  // type, e.g., could not get strides/offset.
2335  assert(succeeded(resultType) && "could not compute layout");
2336  build(builder, result, *resultType, src, reassociation);
2337 }
2338 
2339 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2340  ArrayRef<int64_t> resultShape, Value src,
2341  ArrayRef<ReassociationIndices> reassociation,
2342  ArrayRef<OpFoldResult> outputShape) {
2343  // Only ranked memref source values are supported.
2344  auto srcType = llvm::cast<MemRefType>(src.getType());
2345  FailureOr<MemRefType> resultType =
2346  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2347  // Failure of this assertion usually indicates a problem with the source
2348  // type, e.g., could not get strides/offset.
2349  assert(succeeded(resultType) && "could not compute layout");
2350  build(builder, result, *resultType, src, reassociation, outputShape);
2351 }
2352 
2353 LogicalResult ExpandShapeOp::verify() {
2354  MemRefType srcType = getSrcType();
2355  MemRefType resultType = getResultType();
2356 
2357  if (srcType.getRank() > resultType.getRank()) {
2358  auto r0 = srcType.getRank();
2359  auto r1 = resultType.getRank();
2360  return emitOpError("has source rank ")
2361  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2362  << r0 << " > " << r1 << ").";
2363  }
2364 
2365  // Verify result shape.
2366  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2367  resultType.getShape(),
2368  getReassociationIndices(),
2369  /*allowMultipleDynamicDimsPerGroup=*/true)))
2370  return failure();
2371 
2372  // Compute expected result type (including layout map).
2373  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2374  srcType, resultType.getShape(), getReassociationIndices());
2375  if (failed(expectedResultType))
2376  return emitOpError("invalid source layout map");
2377 
2378  // Check actual result type.
2379  if (*expectedResultType != resultType)
2380  return emitOpError("expected expanded type to be ")
2381  << *expectedResultType << " but found " << resultType;
2382 
2383  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2384  return emitOpError("expected number of static shape bounds to be equal to "
2385  "the output rank (")
2386  << resultType.getRank() << ") but found "
2387  << getStaticOutputShape().size() << " inputs instead";
2388 
2389  if ((int64_t)getOutputShape().size() !=
2390  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2391  return emitOpError("mismatch in dynamic dims in output_shape and "
2392  "static_output_shape: static_output_shape has ")
2393  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2394  << " dynamic dims while output_shape has " << getOutputShape().size()
2395  << " values";
2396 
2397  // Verify if provided output shapes are in agreement with output type.
2398  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2399  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2400  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2401  if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
2402  return emitOpError("invalid output shape provided at pos ") << pos;
2403  }
2404  }
2405 
2406  return success();
2407 }
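// For reference, a well-formed expansion accepted by the verifier above looks
// like (a minimal sketch; names are illustrative):
//
//   %e = memref.expand_shape %m [[0, 1], [2]] output_shape [%d, 5, 4]
//       : memref<?x4xf32> into memref<?x5x4xf32>
//
// where %d holds the dynamic extent of the first result dimension.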
2408 
2409 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2410  MLIRContext *context) {
2411  results.add<
2412  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2413  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2414 }
2415 
2416 /// Compute the layout map after collapsing a given source MemRef type with the
2417 /// specified reassociation indices.
2418 ///
2419 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2420 /// not possible to check this by inspecting a MemRefType in the general case.
2421 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2422 /// be valid (and thus accepted by this function) unless `strict = true`.
2423 static FailureOr<StridedLayoutAttr>
2424 computeCollapsedLayoutMap(MemRefType srcType,
2425  ArrayRef<ReassociationIndices> reassociation,
2426  bool strict = false) {
2427  int64_t srcOffset;
2428  SmallVector<int64_t> srcStrides;
2429  auto srcShape = srcType.getShape();
2430  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2431  return failure();
2432 
2433  // The result stride of a reassociation group is the stride of the last entry
2434  // of the reassociation. (TODO: Should be the minimum stride in the
2435  // reassociation because strides are not necessarily sorted. E.g., when using
2436  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2437  // strides are meaningless and could have any arbitrary value.
2438  SmallVector<int64_t> resultStrides;
2439  resultStrides.reserve(reassociation.size());
2440  for (const ReassociationIndices &reassoc : reassociation) {
2441  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2442  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2443  ref = ref.drop_back();
2444  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2445  resultStrides.push_back(srcStrides[ref.back()]);
2446  } else {
2447  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2448  // the corresponding stride may have to be skipped. (See above comment.)
2449  // Therefore, the result stride cannot be statically determined and must
2450  // be dynamic.
2451  resultStrides.push_back(ShapedType::kDynamic);
2452  }
2453  }
2454 
2455  // Validate that each reassociation group is contiguous.
2456  unsigned resultStrideIndex = resultStrides.size() - 1;
2457  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2458  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2459  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2460  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2461  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2462 
2463  // Both source and result stride must have the same static value. In that
2464  // case, we can be sure, that the dimensions are collapsible (because they
2465  // are contiguous).
2466  // If `strict = false` (default during op verification), we accept cases
2467  // where one or both strides are dynamic. This is best effort: We reject
2468  // ops where obviously non-contiguous dims are collapsed, but accept ops
2469  // where we cannot be sure statically. Such ops may fail at runtime. See
2470  // the op documentation for details.
2471  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2472  if (strict && (stride.saturated || srcStride.saturated))
2473  return failure();
2474 
2475  // Dimensions of size 1 should be skipped, because their strides are
2476  // meaningless and could have any arbitrary value.
2477  if (srcShape[idx - 1] == 1)
2478  continue;
2479 
2480  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2481  return failure();
2482  }
2483  }
2484  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2485 }
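// For illustration (sizes and strides are made up for the example): collapsing
// memref<4x3x2xf32, strided<[6, 2, 1]>> with reassociation [[0, 1, 2]] takes
// the stride of the last entry, 1, and the contiguity walk checks 1 * 2 == 2
// and 2 * 3 == 6 against the source strides, so the group collapses to
// memref<24xf32, strided<[1]>>. With source strides [12, 2, 1] the walk
// reaches 6 != 12 and the collapse is rejected.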
2486 
2487 bool CollapseShapeOp::isGuaranteedCollapsible(
2488  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2489  // MemRefs with identity layout are always collapsible.
2490  if (srcType.getLayout().isIdentity())
2491  return true;
2492 
2493  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2494  /*strict=*/true));
2495 }
2496 
2497 MemRefType CollapseShapeOp::computeCollapsedType(
2498  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2499  SmallVector<int64_t> resultShape;
2500  resultShape.reserve(reassociation.size());
2501  for (const ReassociationIndices &group : reassociation) {
2502  auto groupSize = SaturatedInteger::wrap(1);
2503  for (int64_t srcDim : group)
2504  groupSize =
2505  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2506  resultShape.push_back(groupSize.asInteger());
2507  }
2508 
2509  if (srcType.getLayout().isIdentity()) {
2510  // If the source is contiguous (i.e., no layout map specified), so is the
2511  // result.
2512  MemRefLayoutAttrInterface layout;
2513  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2514  srcType.getMemorySpace());
2515  }
2516 
2517  // Source may not be fully contiguous. Compute the layout map.
2518  // Note: Dimensions that are collapsed into a single dim are assumed to be
2519  // contiguous.
2520  FailureOr<StridedLayoutAttr> computedLayout =
2521  computeCollapsedLayoutMap(srcType, reassociation);
2522  assert(succeeded(computedLayout) &&
2523  "invalid source layout map or collapsing non-contiguous dims");
2524  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2525  srcType.getMemorySpace());
2526 }
2527 
2528 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2529  ArrayRef<ReassociationIndices> reassociation,
2530  ArrayRef<NamedAttribute> attrs) {
2531  auto srcType = llvm::cast<MemRefType>(src.getType());
2532  MemRefType resultType =
2533  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2535  getReassociationIndicesAttribute(b, reassociation));
2536  build(b, result, resultType, src, attrs);
2537 }
2538 
2539 LogicalResult CollapseShapeOp::verify() {
2540  MemRefType srcType = getSrcType();
2541  MemRefType resultType = getResultType();
2542 
2543  if (srcType.getRank() < resultType.getRank()) {
2544  auto r0 = srcType.getRank();
2545  auto r1 = resultType.getRank();
2546  return emitOpError("has source rank ")
2547  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2548  << r0 << " < " << r1 << ").";
2549  }
2550 
2551  // Verify result shape.
2552  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2553  srcType.getShape(), getReassociationIndices(),
2554  /*allowMultipleDynamicDimsPerGroup=*/true)))
2555  return failure();
2556 
2557  // Compute expected result type (including layout map).
2558  MemRefType expectedResultType;
2559  if (srcType.getLayout().isIdentity()) {
2560  // If the source is contiguous (i.e., no layout map specified), so is the
2561  // result.
2562  MemRefLayoutAttrInterface layout;
2563  expectedResultType =
2564  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2565  srcType.getMemorySpace());
2566  } else {
2567  // Source may not be fully contiguous. Compute the layout map.
2568  // Note: Dimensions that are collapsed into a single dim are assumed to be
2569  // contiguous.
2570  FailureOr<StridedLayoutAttr> computedLayout =
2571  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2572  if (failed(computedLayout))
2573  return emitOpError(
2574  "invalid source layout map or collapsing non-contiguous dims");
2575  expectedResultType =
2576  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2577  *computedLayout, srcType.getMemorySpace());
2578  }
2579 
2580  if (expectedResultType != resultType)
2581  return emitOpError("expected collapsed type to be ")
2582  << expectedResultType << " but found " << resultType;
2583 
2584  return success();
2585 }
2586 
2587 class CollapseShapeOpMemRefCastFolder final
2588  : public OpRewritePattern<CollapseShapeOp> {
2589 public:
2590  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2591 
2592  LogicalResult matchAndRewrite(CollapseShapeOp op,
2593  PatternRewriter &rewriter) const override {
2594  auto cast = op.getOperand().getDefiningOp<CastOp>();
2595  if (!cast)
2596  return failure();
2597 
2598  if (!CastOp::canFoldIntoConsumerOp(cast))
2599  return failure();
2600 
2601  Type newResultType = CollapseShapeOp::computeCollapsedType(
2602  llvm::cast<MemRefType>(cast.getOperand().getType()),
2603  op.getReassociationIndices());
2604 
2605  if (newResultType == op.getResultType()) {
2606  rewriter.modifyOpInPlace(
2607  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2608  } else {
2609  Value newOp = rewriter.create<CollapseShapeOp>(
2610  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2611  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2612  }
2613  return success();
2614  }
2615 };
2616 
2617 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2618  MLIRContext *context) {
2619  results.add<
2620  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2621  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2622  memref::DimOp, MemRefType>,
2623  CollapseShapeOpMemRefCastFolder>(context);
2624 }
2625 
2626 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2627  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2628  adaptor.getOperands());
2629 }
2630 
2631 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2632  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2633  adaptor.getOperands());
2634 }
2635 
2636 //===----------------------------------------------------------------------===//
2637 // ReshapeOp
2638 //===----------------------------------------------------------------------===//
2639 
2640 void ReshapeOp::getAsmResultNames(
2641  function_ref<void(Value, StringRef)> setNameFn) {
2642  setNameFn(getResult(), "reshape");
2643 }
2644 
2645 LogicalResult ReshapeOp::verify() {
2646  Type operandType = getSource().getType();
2647  Type resultType = getResult().getType();
2648 
2649  Type operandElementType =
2650  llvm::cast<ShapedType>(operandType).getElementType();
2651  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2652  if (operandElementType != resultElementType)
2653  return emitOpError("element types of source and destination memref "
2654  "types should be the same");
2655 
2656  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2657  if (!operandMemRefType.getLayout().isIdentity())
2658  return emitOpError("source memref type should have identity affine map");
2659 
2660  int64_t shapeSize =
2661  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2662  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2663  if (resultMemRefType) {
2664  if (!resultMemRefType.getLayout().isIdentity())
2665  return emitOpError("result memref type should have identity affine map");
2666  if (shapeSize == ShapedType::kDynamic)
2667  return emitOpError("cannot use shape operand with dynamic length to "
2668  "reshape to statically-ranked memref type");
2669  if (shapeSize != resultMemRefType.getRank())
2670  return emitOpError(
2671  "length of shape operand differs from the result's memref rank");
2672  }
2673  return success();
2674 }
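// For reference, a reshape accepted by the verifier above looks like (a
// minimal sketch; names are illustrative):
//
//   %dst = memref.reshape %src(%shape)
//       : (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>
//
// where %shape is a rank-1 memref holding the target extents.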
2675 
2676 //===----------------------------------------------------------------------===//
2677 // StoreOp
2678 //===----------------------------------------------------------------------===//
2679 
2680 LogicalResult StoreOp::verify() {
2681  if (getNumOperands() != 2 + getMemRefType().getRank())
2682  return emitOpError("store index operand count not equal to memref rank");
2683 
2684  return success();
2685 }
2686 
2687 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2688  SmallVectorImpl<OpFoldResult> &results) {
2689  /// store(memrefcast) -> store
2690  return foldMemRefCast(*this, getValueToStore());
2691 }
2692 
2693 //===----------------------------------------------------------------------===//
2694 // SubViewOp
2695 //===----------------------------------------------------------------------===//
2696 
2697 void SubViewOp::getAsmResultNames(
2698  function_ref<void(Value, StringRef)> setNameFn) {
2699  setNameFn(getResult(), "subview");
2700 }
2701 
2702 /// A subview result type can be fully inferred from the source type and the
2703 /// static representation of offsets, sizes and strides. Special sentinels
2704 /// encode the dynamic case.
2705 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2706  ArrayRef<int64_t> staticOffsets,
2707  ArrayRef<int64_t> staticSizes,
2708  ArrayRef<int64_t> staticStrides) {
2709  unsigned rank = sourceMemRefType.getRank();
2710  (void)rank;
2711  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2712  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2713  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2714 
2715  // Extract source offset and strides.
2716  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2717 
2718  // Compute target offset whose value is:
2719  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2720  int64_t targetOffset = sourceOffset;
2721  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2722  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2723  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2724  SaturatedInteger::wrap(staticOffset) *
2725  SaturatedInteger::wrap(sourceStride))
2726  .asInteger();
2727  }
2728 
2729  // Compute target stride whose value is:
2730  // `sourceStrides_i * staticStrides_i`.
2731  SmallVector<int64_t, 4> targetStrides;
2732  targetStrides.reserve(staticOffsets.size());
2733  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2734  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2735  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2736  SaturatedInteger::wrap(staticStride))
2737  .asInteger());
2738  }
2739 
2740  // The type is now known.
2741  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2742  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2743  targetOffset, targetStrides),
2744  sourceMemRefType.getMemorySpace());
2745 }
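// For illustration (numbers are made up for the example): for a source
// memref<8x16xf32> (strides [16, 1], offset 0) with static offsets [2, 4],
// sizes [4, 4] and strides [1, 2], the inferred result type is
//   memref<4x4xf32, strided<[16, 2], offset: 36>>
// since 0 + 2*16 + 4*1 == 36 and the target strides are [16*1, 1*2].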
2746 
2747 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2748  ArrayRef<OpFoldResult> offsets,
2749  ArrayRef<OpFoldResult> sizes,
2750  ArrayRef<OpFoldResult> strides) {
2751  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2752  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2753  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2754  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2755  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2756  if (!hasValidSizesOffsets(staticOffsets))
2757  return {};
2758  if (!hasValidSizesOffsets(staticSizes))
2759  return {};
2760  if (!hasValidStrides(staticStrides))
2761  return {};
2762  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2763  staticSizes, staticStrides);
2764 }
2765 
2766 MemRefType SubViewOp::inferRankReducedResultType(
2767  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2768  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2769  ArrayRef<int64_t> strides) {
2770  MemRefType inferredType =
2771  inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2772  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2773  "expected ");
2774  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2775  return inferredType;
2776 
2777  // Compute which dimensions are dropped.
2778  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2779  computeRankReductionMask(inferredType.getShape(), resultShape);
2780  assert(dimsToProject.has_value() && "invalid rank reduction");
2781 
2782  // Compute the layout and result type.
2783  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2784  SmallVector<int64_t> rankReducedStrides;
2785  rankReducedStrides.reserve(resultShape.size());
2786  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2787  if (!dimsToProject->contains(idx))
2788  rankReducedStrides.push_back(value);
2789  }
2790  return MemRefType::get(resultShape, inferredType.getElementType(),
2791  StridedLayoutAttr::get(inferredLayout.getContext(),
2792  inferredLayout.getOffset(),
2793  rankReducedStrides),
2794  inferredType.getMemorySpace());
2795 }
2796 
2797 MemRefType SubViewOp::inferRankReducedResultType(
2798  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2799  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2800  ArrayRef<OpFoldResult> strides) {
2801  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2802  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2803  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2804  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2805  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2806  return SubViewOp::inferRankReducedResultType(
2807  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2808  staticStrides);
2809 }
2810 
2811 // Build a SubViewOp with mixed static and dynamic entries and custom result
2812 // type. If the type passed is nullptr, it is inferred.
2813 void SubViewOp::build(OpBuilder &b, OperationState &result,
2814  MemRefType resultType, Value source,
2815  ArrayRef<OpFoldResult> offsets,
2816  ArrayRef<OpFoldResult> sizes,
2817  ArrayRef<OpFoldResult> strides,
2818  ArrayRef<NamedAttribute> attrs) {
2819  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2820  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2821  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2822  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2823  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2824  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2825  // Structuring implementation this way avoids duplication between builders.
2826  if (!resultType) {
2827  resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2828  staticSizes, staticStrides);
2829  }
2830  result.addAttributes(attrs);
2831  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2832  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2833  b.getDenseI64ArrayAttr(staticSizes),
2834  b.getDenseI64ArrayAttr(staticStrides));
2835 }
2836 
2837 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2838 // type.
2839 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2840  ArrayRef<OpFoldResult> offsets,
2841  ArrayRef<OpFoldResult> sizes,
2842  ArrayRef<OpFoldResult> strides,
2843  ArrayRef<NamedAttribute> attrs) {
2844  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2845 }
2846 
2847 // Build a SubViewOp with static entries and inferred result type.
2848 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2849  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2850  ArrayRef<int64_t> strides,
2851  ArrayRef<NamedAttribute> attrs) {
2852  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2853  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2854  return b.getI64IntegerAttr(v);
2855  }));
2856  SmallVector<OpFoldResult> sizeValues =
2857  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2858  return b.getI64IntegerAttr(v);
2859  }));
2860  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2861  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2862  return b.getI64IntegerAttr(v);
2863  }));
2864  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2865 }
2866 
2867 // Build a SubViewOp with static entries and custom result type. If the
2868 // type passed is nullptr, it is inferred.
2869 void SubViewOp::build(OpBuilder &b, OperationState &result,
2870  MemRefType resultType, Value source,
2871  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2872  ArrayRef<int64_t> strides,
2873  ArrayRef<NamedAttribute> attrs) {
2874  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2875  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2876  return b.getI64IntegerAttr(v);
2877  }));
2878  SmallVector<OpFoldResult> sizeValues =
2879  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2880  return b.getI64IntegerAttr(v);
2881  }));
2882  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2883  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2884  return b.getI64IntegerAttr(v);
2885  }));
2886  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2887  attrs);
2888 }
2889 
2890 // Build a SubViewOp with dynamic entries and custom result type. If the type
2891 // passed is nullptr, it is inferred.
2892 void SubViewOp::build(OpBuilder &b, OperationState &result,
2893  MemRefType resultType, Value source, ValueRange offsets,
2894  ValueRange sizes, ValueRange strides,
2895  ArrayRef<NamedAttribute> attrs) {
2896  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2897  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2898  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2899  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2900  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2901  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2902  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2903 }
2904 
2905 // Build a SubViewOp with dynamic entries and inferred result type.
2906 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2907  ValueRange offsets, ValueRange sizes, ValueRange strides,
2908  ArrayRef<NamedAttribute> attrs) {
2909  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2910 }
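
Taken together, these overloads let a caller describe a tile with any mix of static and dynamic offsets, sizes, and strides. A hypothetical usage sketch (the helper `buildTile`, the 4x16 tile shape, and the surrounding values are assumptions, not part of this file):

```cpp
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch: `source` is assumed to have type memref<?x16xf32> and `tileStart`
// to be an index-typed SSA value; `b` is positioned inside a function body.
static Value buildTile(OpBuilder &b, Location loc, Value source,
                       Value tileStart) {
  SmallVector<OpFoldResult> offsets = {tileStart, b.getIndexAttr(0)};
  SmallVector<OpFoldResult> sizes = {b.getIndexAttr(4), b.getIndexAttr(16)};
  SmallVector<OpFoldResult> strides = {b.getIndexAttr(1), b.getIndexAttr(1)};
  // Passing no result type picks the overload above that infers it.
  return b.create<memref::SubViewOp>(loc, source, offsets, sizes, strides);
}
```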
2911 
2912 /// For ViewLikeOpInterface.
2913 Value SubViewOp::getViewSource() { return getSource(); }
2914 
2915 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2916 /// static value).
2917 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2918  int64_t t1Offset, t2Offset;
2919  SmallVector<int64_t> t1Strides, t2Strides;
2920  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2921  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2922  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2923 }
2924 
2925 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2926 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2927 /// marked as dropped in `droppedDims`.
2928 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2929  const llvm::SmallBitVector &droppedDims) {
2930  assert(size_t(t1.getRank()) == droppedDims.size() &&
2931  "incorrect number of bits");
2932  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2933  "incorrect number of dropped dims");
2934  int64_t t1Offset, t2Offset;
2935  SmallVector<int64_t> t1Strides, t2Strides;
2936  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2937  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2938  if (failed(res1) || failed(res2))
2939  return false;
2940  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2941  if (droppedDims[i])
2942  continue;
2943  if (t1Strides[i] != t2Strides[j])
2944  return false;
2945  ++j;
2946  }
2947  return true;
2948 }
2949 
2950 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2951  Operation *op, Type expectedType) {
2952  auto memrefType = llvm::cast<ShapedType>(expectedType);
2953  switch (result) {
2954  case SliceVerificationResult::Success:
2955  return success();
2956  case SliceVerificationResult::RankTooLarge:
2957  return op->emitError("expected result rank to be smaller or equal to ")
2958  << "the source rank. ";
2959  case SliceVerificationResult::SizeMismatch:
2960  return op->emitError("expected result type to be ")
2961  << expectedType
2962  << " or a rank-reduced version. (mismatch of result sizes) ";
2963  case SliceVerificationResult::ElemTypeMismatch:
2964  return op->emitError("expected result element type to be ")
2965  << memrefType.getElementType();
2966  case SliceVerificationResult::MemSpaceMismatch:
2967  return op->emitError("expected result and source memory spaces to match.");
2968  case SliceVerificationResult::LayoutMismatch:
2969  return op->emitError("expected result type to be ")
2970  << expectedType
2971  << " or a rank-reduced version. (mismatch of result layout) ";
2972  }
2973  llvm_unreachable("unexpected subview verification result");
2974 }
2975 
2976 /// Verifier for SubViewOp.
2977 LogicalResult SubViewOp::verify() {
2978  MemRefType baseType = getSourceType();
2979  MemRefType subViewType = getType();
2980 
2981  // The base memref and the view memref should be in the same memory space.
2982  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2983  return emitError("different memory spaces specified for base memref "
2984  "type ")
2985  << baseType << " and subview memref type " << subViewType;
2986 
2987  // Verify that the base memref type has a strided layout map.
2988  if (!baseType.isStrided())
2989  return emitError("base type ") << baseType << " is not strided";
2990 
2991  // Compute the expected result type, assuming that there are no rank
2992  // reductions.
2993  MemRefType expectedType = SubViewOp::inferResultType(
2994  baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
2995 
2996  // Verify all properties of a shaped type: rank, element type and dimension
2997  // sizes. This takes into account potential rank reductions.
2998  auto shapedTypeVerification = isRankReducedType(
2999  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
3000  if (shapedTypeVerification != SliceVerificationResult::Success)
3001  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
3002 
3003  // Make sure that the memory space did not change.
3004  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3005  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
3006  *this, expectedType);
3007 
3008  // Verify the offset of the layout map.
3009  if (!haveCompatibleOffsets(expectedType, subViewType))
3010  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3011  *this, expectedType);
3012 
3013  // The only things left to verify now are the strides. First, compute
3014  // the unused dimensions due to rank reductions. We have to look at sizes and
3015  // strides to decide which dimensions were dropped. This function also
3016  // partially verifies strides in case of rank reductions.
3017  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3018  getMixedSizes());
3019  if (failed(unusedDims))
3020  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3021  *this, expectedType);
3022 
3023  // Strides must match.
3024  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3025  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3026  *this, expectedType);
3027 
3028  return success();
3029 }
3030 
3031 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3032  return os << "range " << range.offset << ":" << range.size << ":"
3033  << range.stride;
3034 }
3035 
3036 /// Return the list of Range (i.e. offset, size, stride). Each Range
3037 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3038 /// with `b` at location `loc`.
3039 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3040  OpBuilder &b, Location loc) {
3041  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3042  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3043  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3044  SmallVector<Range, 8> res;
3045  unsigned rank = ranks[0];
3046  res.reserve(rank);
3047  for (unsigned idx = 0; idx < rank; ++idx) {
3048  Value offset =
3049  op.isDynamicOffset(idx)
3050  ? op.getDynamicOffset(idx)
3051  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
3052  Value size =
3053  op.isDynamicSize(idx)
3054  ? op.getDynamicSize(idx)
3055  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
3056  Value stride =
3057  op.isDynamicStride(idx)
3058  ? op.getDynamicStride(idx)
3059  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
3060  res.emplace_back(Range{offset, size, stride});
3061  }
3062  return res;
3063 }
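
For instance (a sketch under assumptions: `subview` is an existing memref::SubViewOp, `b` an OpBuilder with a valid insertion point, and the ViewLikeInterface header is included), a transformation that needs SSA values for every offset/size/stride can materialize them in one call:

```cpp
auto iface = cast<OffsetSizeAndStrideOpInterface>(subview.getOperation());
SmallVector<Range, 8> ranges = getOrCreateRanges(iface, b, subview.getLoc());
for (const Range &r : ranges) {
  // r.offset, r.size, and r.stride each wrap an SSA Value here: static entries
  // were materialized as arith.constant index ops by getOrCreateRanges.
  (void)r;
}
```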
3064 
3065 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3066 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3067 /// the rank of the inferred result type if `currentResultType` is lower rank
3068 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3069 /// together with the result type. In this case, it is important to compute
3070 /// the dropped dimensions using `currentSourceType` whose strides align with
3071 /// `currentResultType`.
3072 static MemRefType getCanonicalSubViewResultType(
3073  MemRefType currentResultType, MemRefType currentSourceType,
3074  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3075  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3076  MemRefType nonRankReducedType = SubViewOp::inferResultType(
3077  sourceType, mixedOffsets, mixedSizes, mixedStrides);
3078  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3079  currentSourceType, currentResultType, mixedSizes);
3080  if (failed(unusedDims))
3081  return nullptr;
3082 
3083  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3084  SmallVector<int64_t> shape, strides;
3085  unsigned numDimsAfterReduction =
3086  nonRankReducedType.getRank() - unusedDims->count();
3087  shape.reserve(numDimsAfterReduction);
3088  strides.reserve(numDimsAfterReduction);
3089  for (const auto &[idx, size, stride] :
3090  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3091  nonRankReducedType.getShape(), layout.getStrides())) {
3092  if (unusedDims->test(idx))
3093  continue;
3094  shape.push_back(size);
3095  strides.push_back(stride);
3096  }
3097 
3098  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3099  StridedLayoutAttr::get(sourceType.getContext(),
3100  layout.getOffset(), strides),
3101  nonRankReducedType.getMemorySpace());
3102 }
3103 
3104 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3105  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3106  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3107  unsigned rank = memrefType.getRank();
3108  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3109  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3110  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3111  MemRefType targetType = SubViewOp::inferRankReducedResultType(
3112  targetShape, memrefType, offsets, sizes, strides);
3113  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3114  sizes, strides);
3115 }
3116 
3117 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3118  Value value,
3119  ArrayRef<int64_t> desiredShape) {
3120  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3121  assert(sourceMemrefType && "not a ranked memref type");
3122  auto sourceShape = sourceMemrefType.getShape();
3123  if (sourceShape.equals(desiredShape))
3124  return value;
3125  auto maybeRankReductionMask =
3126  mlir::computeRankReductionMask(sourceShape, desiredShape);
3127  if (!maybeRankReductionMask)
3128  return failure();
3129  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3130 }
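
A hedged usage sketch (the value names and the 1x16x1 shape are invented): collapsing the unit dimensions of a memref<1x16x1xf32> down to a memref<16xf32>, or bailing out when no rank-reduction mask exists:

```cpp
// Sketch: `b`, `loc`, and `src` (of type memref<1x16x1xf32>) are assumed to be
// provided by the surrounding pattern, which returns a LogicalResult.
FailureOr<Value> reduced =
    memref::SubViewOp::rankReduceIfNeeded(b, loc, src, /*desiredShape=*/{16});
if (failed(reduced))
  return failure();
// On success, *reduced is either `src` itself (shapes already matched) or a
// canonical rank-reducing subview with zero offsets, full sizes, unit strides.
Value collapsed = *reduced;
```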
3131 
3132 /// Helper method to check if a `subview` operation is trivially a no-op. This
3133 /// is the case if all offsets are zero, all strides are 1, and the source
3134 /// shape is the same as the size of the subview. In such cases, the subview
3135 /// be folded into its source.
3136 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3137  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3138  return false;
3139 
3140  auto mixedOffsets = subViewOp.getMixedOffsets();
3141  auto mixedSizes = subViewOp.getMixedSizes();
3142  auto mixedStrides = subViewOp.getMixedStrides();
3143 
3144  // Check offsets are zero.
3145  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3146  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3147  return !intValue || intValue.value() != 0;
3148  }))
3149  return false;
3150 
3151  // Check strides are one.
3152  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3153  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3154  return !intValue || intValue.value() != 1;
3155  }))
3156  return false;
3157 
3158  // Check all size values are static and match the (static) source shape.
3159  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3160  for (const auto &size : llvm::enumerate(mixedSizes)) {
3161  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3162  if (!intValue || *intValue != sourceShape[size.index()])
3163  return false;
3164  }
3165  // All conditions met. The `SubViewOp` is foldable as a no-op.
3166  return true;
3167 }
3168 
3169 namespace {
3170 /// Pattern to rewrite a subview op with MemRefCast arguments.
3171 /// This essentially pushes memref.cast past its consuming subview when
3172 /// `canFoldIntoConsumerOp` is true.
3173 ///
3174 /// Example:
3175 /// ```
3176 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3177 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3178 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3179 /// ```
3180 /// is rewritten into:
3181 /// ```
3182 /// %0 = memref.subview %V: memref<16x16xf32> to
3183 ///  memref<3x4xf32, strided<[16, 1], offset: 0>>
3183 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3184 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3185 /// ```
3186 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3187 public:
3188  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3189 
3190  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3191  PatternRewriter &rewriter) const override {
3192  // If any operand is a constant, just return and let
3193  // SubViewOpConstantFolder kick in.
3194  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3195  return matchPattern(operand, matchConstantIndex());
3196  }))
3197  return failure();
3198 
3199  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3200  if (!castOp)
3201  return failure();
3202 
3203  if (!CastOp::canFoldIntoConsumerOp(castOp))
3204  return failure();
3205 
3206  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3207  // the MemRefCastOp source operand type to infer the result type and the
3208  // current SubViewOp source operand type to compute the dropped dimensions
3209  // if the operation is rank-reducing.
3210  auto resultType = getCanonicalSubViewResultType(
3211  subViewOp.getType(), subViewOp.getSourceType(),
3212  llvm::cast<MemRefType>(castOp.getSource().getType()),
3213  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3214  subViewOp.getMixedStrides());
3215  if (!resultType)
3216  return failure();
3217 
3218  Value newSubView = rewriter.create<SubViewOp>(
3219  subViewOp.getLoc(), resultType, castOp.getSource(),
3220  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3221  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3222  subViewOp.getStaticStrides());
3223  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3224  newSubView);
3225  return success();
3226  }
3227 };
3228 
3229 /// Canonicalize subview ops that are no-ops. If the source type differs from
3230 /// the result type (e.g. because of its layout), a cast is inserted instead.
3231 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3232 public:
3233  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3234 
3235  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3236  PatternRewriter &rewriter) const override {
3237  if (!isTrivialSubViewOp(subViewOp))
3238  return failure();
3239  if (subViewOp.getSourceType() == subViewOp.getType()) {
3240  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3241  return success();
3242  }
3243  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3244  subViewOp.getSource());
3245  return success();
3246  }
3247 };
3248 } // namespace
3249 
3250 /// Return the canonical type of the result of a subview.
3251 struct SubViewReturnTypeCanonicalizer {
3252  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3253  ArrayRef<OpFoldResult> mixedSizes,
3254  ArrayRef<OpFoldResult> mixedStrides) {
3255  // Infer a memref type without taking into account any rank reductions.
3256  MemRefType resTy = SubViewOp::inferResultType(
3257  op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3258  if (!resTy)
3259  return {};
3260  MemRefType nonReducedType = resTy;
3261 
3262  // Directly return the non-rank reduced type if there are no dropped dims.
3263  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3264  if (droppedDims.none())
3265  return nonReducedType;
3266 
3267  // Take the strides and offset from the non-rank reduced type.
3268  auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3269 
3270  // Drop dims from shape and strides.
3271  SmallVector<int64_t> targetShape;
3272  SmallVector<int64_t> targetStrides;
3273  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3274  if (droppedDims.test(i))
3275  continue;
3276  targetStrides.push_back(nonReducedStrides[i]);
3277  targetShape.push_back(nonReducedType.getDimSize(i));
3278  }
3279 
3280  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3281  StridedLayoutAttr::get(nonReducedType.getContext(),
3282  offset, targetStrides),
3283  nonReducedType.getMemorySpace());
3284  }
3285 };
3286 
3287 /// A canonicalizer wrapper to replace SubViewOps.
3288 struct SubViewCanonicalizer {
3289  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3290  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3291  }
3292 };
3293 
3294 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3295  MLIRContext *context) {
3296  results
3297  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3298  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3299  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3300 }
3301 
3302 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3303  MemRefType sourceMemrefType = getSource().getType();
3304  MemRefType resultMemrefType = getResult().getType();
3305  auto resultLayout =
3306  dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3307 
3308  if (resultMemrefType == sourceMemrefType &&
3309  resultMemrefType.hasStaticShape() &&
3310  (!resultLayout || resultLayout.hasStaticLayout())) {
3311  return getViewSource();
3312  }
3313 
3314  // Fold subview(subview(x)), where both subviews have the same size and the
3315  // second subview's offsets are all zero. (I.e., the second subview is a
3316  // no-op.)
3317  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3318  auto srcSizes = srcSubview.getMixedSizes();
3319  auto sizes = getMixedSizes();
3320  auto offsets = getMixedOffsets();
3321  bool allOffsetsZero = llvm::all_of(
3322  offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3323  auto strides = getMixedStrides();
3324  bool allStridesOne = llvm::all_of(
3325  strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3326  bool allSizesSame = llvm::equal(sizes, srcSizes);
3327  if (allOffsetsZero && allStridesOne && allSizesSame &&
3328  resultMemrefType == sourceMemrefType)
3329  return getViewSource();
3330  }
3331 
3332  return {};
3333 }
3334 
3335 //===----------------------------------------------------------------------===//
3336 // TransposeOp
3337 //===----------------------------------------------------------------------===//
3338 
3339 void TransposeOp::getAsmResultNames(
3340  function_ref<void(Value, StringRef)> setNameFn) {
3341  setNameFn(getResult(), "transpose");
3342 }
3343 
3344 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3345 static MemRefType inferTransposeResultType(MemRefType memRefType,
3346  AffineMap permutationMap) {
3347  auto originalSizes = memRefType.getShape();
3348  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3349  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3350 
3351  // Compute permuted sizes and strides.
3352  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3353  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3354 
3355  return MemRefType::Builder(memRefType)
3356  .setShape(sizes)
3357  .setLayout(
3358  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3359 }
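
To make the computation above concrete, here is a standalone sketch (nothing in it comes from this file) that permutes the sizes and strides of a row-major 4x8 buffer with the swap map:

```cpp
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

static void permuteSizesAndStrides() {
  MLIRContext ctx;
  // (d0, d1) -> (d1, d0): a plain 2-D transpose.
  AffineMap perm = AffineMap::getPermutationMap({1, 0}, &ctx);
  // A row-major memref<4x8xf32> has sizes {4, 8} and strides {8, 1}.
  SmallVector<int64_t> sizes =
      applyPermutationMap<int64_t>(perm, ArrayRef<int64_t>{4, 8});
  SmallVector<int64_t> strides =
      applyPermutationMap<int64_t>(perm, ArrayRef<int64_t>{8, 1});
  // sizes == {8, 4} and strides == {1, 8}, so the inferred result type is
  // memref<8x4xf32, strided<[1, 8]>>.
  (void)sizes;
  (void)strides;
}
```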
3360 
3361 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3362  AffineMapAttr permutation,
3363  ArrayRef<NamedAttribute> attrs) {
3364  auto permutationMap = permutation.getValue();
3365  assert(permutationMap);
3366 
3367  auto memRefType = llvm::cast<MemRefType>(in.getType());
3368  // Compute result type.
3369  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3370 
3371  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3372  build(b, result, resultType, in, attrs);
3373 }
3374 
3375 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3376 void TransposeOp::print(OpAsmPrinter &p) {
3377  p << " " << getIn() << " " << getPermutation();
3378  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3379  p << " : " << getIn().getType() << " to " << getType();
3380 }
3381 
3382 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3383  OpAsmParser::UnresolvedOperand in;
3384  AffineMap permutation;
3385  MemRefType srcType, dstType;
3386  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3387  parser.parseOptionalAttrDict(result.attributes) ||
3388  parser.parseColonType(srcType) ||
3389  parser.resolveOperand(in, srcType, result.operands) ||
3390  parser.parseKeywordType("to", dstType) ||
3391  parser.addTypeToList(dstType, result.types))
3392  return failure();
3393 
3394  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3395  AffineMapAttr::get(permutation));
3396  return success();
3397 }
3398 
3399 LogicalResult TransposeOp::verify() {
3400  if (!getPermutation().isPermutation())
3401  return emitOpError("expected a permutation map");
3402  if (getPermutation().getNumDims() != getIn().getType().getRank())
3403  return emitOpError("expected a permutation map of same rank as the input");
3404 
3405  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3406  auto resultType = llvm::cast<MemRefType>(getType());
3407  auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3408  .canonicalizeStridedLayout();
3409 
3410  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3411  return emitOpError("result type ")
3412  << resultType
3413  << " is not equivalent to the canonical transposed input type "
3414  << canonicalResultType;
3415  return success();
3416 }
3417 
3418 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3419  // First check for identity permutation, we can fold it away if input and
3420  // result types are identical already.
3421  if (getPermutation().isIdentity() && getType() == getIn().getType())
3422  return getIn();
3423  // Fold two consecutive memref.transpose Ops into one by composing their
3424  // permutation maps.
3425  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3426  AffineMap composedPermutation =
3427  getPermutation().compose(otherTransposeOp.getPermutation());
3428  getInMutable().assign(otherTransposeOp.getIn());
3429  setPermutation(composedPermutation);
3430  return getResult();
3431  }
3432  return {};
3433 }
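
The second fold relies on permutation maps composing cleanly. As a tiny sketch (assuming an MLIRContext `ctx` as in the earlier snippet), swapping twice composes to the identity map, so two back-to-back transposes eventually disappear through the identity case above:

```cpp
AffineMap swap = AffineMap::getPermutationMap({1, 0}, &ctx);
assert(swap.compose(swap).isIdentity() && "double swap is the identity");
```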
3434 
3435 //===----------------------------------------------------------------------===//
3436 // ViewOp
3437 //===----------------------------------------------------------------------===//
3438 
3439 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3440  setNameFn(getResult(), "view");
3441 }
3442 
3443 LogicalResult ViewOp::verify() {
3444  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3445  auto viewType = getType();
3446 
3447  // The base memref should have identity layout map (or none).
3448  if (!baseType.getLayout().isIdentity())
3449  return emitError("unsupported map for base memref type ") << baseType;
3450 
3451  // The result memref should have identity layout map (or none).
3452  if (!viewType.getLayout().isIdentity())
3453  return emitError("unsupported map for result memref type ") << viewType;
3454 
3455  // The base memref and the view memref should be in the same memory space.
3456  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3457  return emitError("different memory spaces specified for base memref "
3458  "type ")
3459  << baseType << " and view memref type " << viewType;
3460 
3461  // Verify that we have the correct number of sizes for the result type.
3462  unsigned numDynamicDims = viewType.getNumDynamicDims();
3463  if (getSizes().size() != numDynamicDims)
3464  return emitError("incorrect number of size operands for type ") << viewType;
3465 
3466  return success();
3467 }
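
A hypothetical snippet that satisfies these checks (buffer size, shapes, and names are invented; `b` and `loc` are assumed, with the memref and arith dialects loaded): viewing a flat 256-byte i8 buffer as a static 8x8xf32 memref needs no size operands because the result type has no dynamic dimensions:

```cpp
Value buffer =
    b.create<memref::AllocOp>(loc, MemRefType::get({256}, b.getI8Type()));
Value zeroByteShift = b.create<arith::ConstantIndexOp>(loc, 0);
auto viewTy = MemRefType::get({8, 8}, b.getF32Type());
Value view = b.create<memref::ViewOp>(loc, viewTy, buffer, zeroByteShift,
                                      /*sizes=*/ValueRange{});
```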
3468 
3469 Value ViewOp::getViewSource() { return getSource(); }
3470 
3471 namespace {
3472 
3473 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3474  using OpRewritePattern<ViewOp>::OpRewritePattern;
3475 
3476  LogicalResult matchAndRewrite(ViewOp viewOp,
3477  PatternRewriter &rewriter) const override {
3478  // Return if none of the operands are constants.
3479  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3480  return matchPattern(operand, matchConstantIndex());
3481  }))
3482  return failure();
3483 
3484  // Get result memref type.
3485  auto memrefType = viewOp.getType();
3486 
3487  // Get offset from old memref view type 'memRefType'.
3488  int64_t oldOffset;
3489  SmallVector<int64_t, 4> oldStrides;
3490  if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3491  return failure();
3492  assert(oldOffset == 0 && "Expected 0 offset");
3493 
3494  SmallVector<Value, 4> newOperands;
3495 
3496  // Offset cannot be folded into result type.
3497 
3498  // Fold any dynamic dim operands which are produced by a constant.
3499  SmallVector<int64_t, 4> newShapeConstants;
3500  newShapeConstants.reserve(memrefType.getRank());
3501 
3502  unsigned dynamicDimPos = 0;
3503  unsigned rank = memrefType.getRank();
3504  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3505  int64_t dimSize = memrefType.getDimSize(dim);
3506  // If this is already a static dimension, keep it.
3507  if (!ShapedType::isDynamic(dimSize)) {
3508  newShapeConstants.push_back(dimSize);
3509  continue;
3510  }
3511  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3512  if (auto constantIndexOp =
3513  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3514  // Dynamic shape dimension will be folded.
3515  newShapeConstants.push_back(constantIndexOp.value());
3516  } else {
3517  // Dynamic shape dimension not folded; copy operand from old memref.
3518  newShapeConstants.push_back(dimSize);
3519  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3520  }
3521  dynamicDimPos++;
3522  }
3523 
3524  // Create new memref type with constant folded dims.
3525  MemRefType newMemRefType =
3526  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3527  // Nothing new, don't fold.
3528  if (newMemRefType == memrefType)
3529  return failure();
3530 
3531  // Create new ViewOp.
3532  auto newViewOp = rewriter.create<ViewOp>(
3533  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3534  viewOp.getByteShift(), newOperands);
3535  // Insert a cast so we have the same type as the old memref type.
3536  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3537  return success();
3538  }
3539 };
3540 
3541 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3542  using OpRewritePattern<ViewOp>::OpRewritePattern;
3543 
3544  LogicalResult matchAndRewrite(ViewOp viewOp,
3545  PatternRewriter &rewriter) const override {
3546  Value memrefOperand = viewOp.getOperand(0);
3547  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3548  if (!memrefCastOp)
3549  return failure();
3550  Value allocOperand = memrefCastOp.getOperand();
3551  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3552  if (!allocOp)
3553  return failure();
3554  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3555  viewOp.getByteShift(),
3556  viewOp.getSizes());
3557  return success();
3558  }
3559 };
3560 
3561 } // namespace
3562 
3563 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3564  MLIRContext *context) {
3565  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3566 }
3567 
3568 //===----------------------------------------------------------------------===//
3569 // AtomicRMWOp
3570 //===----------------------------------------------------------------------===//
3571 
3572 LogicalResult AtomicRMWOp::verify() {
3573  if (getMemRefType().getRank() != getNumOperands() - 2)
3574  return emitOpError(
3575  "expects the number of subscripts to be equal to memref rank");
3576  switch (getKind()) {
3577  case arith::AtomicRMWKind::addf:
3578  case arith::AtomicRMWKind::maximumf:
3579  case arith::AtomicRMWKind::minimumf:
3580  case arith::AtomicRMWKind::mulf:
3581  if (!llvm::isa<FloatType>(getValue().getType()))
3582  return emitOpError() << "with kind '"
3583  << arith::stringifyAtomicRMWKind(getKind())
3584  << "' expects a floating-point type";
3585  break;
3586  case arith::AtomicRMWKind::addi:
3587  case arith::AtomicRMWKind::maxs:
3588  case arith::AtomicRMWKind::maxu:
3589  case arith::AtomicRMWKind::mins:
3590  case arith::AtomicRMWKind::minu:
3591  case arith::AtomicRMWKind::muli:
3592  case arith::AtomicRMWKind::ori:
3593  case arith::AtomicRMWKind::andi:
3594  if (!llvm::isa<IntegerType>(getValue().getType()))
3595  return emitOpError() << "with kind '"
3596  << arith::stringifyAtomicRMWKind(getKind())
3597  << "' expects an integer type";
3598  break;
3599  default:
3600  break;
3601  }
3602  return success();
3603 }
3604 
3605 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3606  /// atomicrmw(memrefcast) -> atomicrmw
3607  if (succeeded(foldMemRefCast(*this, getValue())))
3608  return getResult();
3609  return OpFoldResult();
3610 }
3611 
3612 //===----------------------------------------------------------------------===//
3613 // TableGen'd op method definitions
3614 //===----------------------------------------------------------------------===//
3615 
3616 #define GET_OP_CLASSES
3617 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"