MemRefOps.cpp
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
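///
/// For illustration (hypothetical IR, not taken from a test), a fold of this
/// shape rewrites
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// into
///   memref.dealloc %arg : memref<8xf32>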
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
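/// For example, memref<4x?xf32> maps to tensor<4x?xf32>, memref<*xf32> maps
/// to tensor<*xf32>, and any non-memref type maps to NoneType.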
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
                                  int64_t dim) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());

  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  return builder.getIndexAttr(memrefType.getDimSize(dim));
}

SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < memrefType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}

//===----------------------------------------------------------------------===//
// Utility functions for propagating static information
//===----------------------------------------------------------------------===//

/// Helper function that infers the constant values from a list of \p values,
/// a \p memRefTy, and another helper function \p getAttributes.
/// The inferred constant values replace the related `OpFoldResult` in
/// \p values.
///
/// \note This function shouldn't be used directly, instead, use the
/// `getConstifiedMixedXXX` methods from the related operations.
///
/// \p getAttributes returns a list of potentially constant values, as
/// determined by \p isDynamic, from the given \p memRefTy. The returned list
/// must have as many elements as \p values or be empty.
///
/// E.g., consider the following example:
/// ```
/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
///     memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
/// ```
/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
/// Now using this helper function with:
/// - `values == [2, %dyn_stride]`,
/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
/// - `getAttributes == getConstantStrides` (i.e., a wrapper around
///   `getStridesAndOffset`), and
/// - `isDynamic == ShapedType::isDynamic`
/// Will yield: `values == [2, 1]`
static void constifyIndexValues(
    SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
    MLIRContext *ctxt,
    llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
    llvm::function_ref<bool(int64_t)> isDynamic) {
  SmallVector<int64_t> constValues = getAttributes(memRefTy);
  Builder builder(ctxt);
  for (const auto &it : llvm::enumerate(constValues)) {
    int64_t constValue = it.value();
    if (!isDynamic(constValue))
      values[it.index()] = builder.getIndexAttr(constValue);
  }
  for (OpFoldResult &ofr : values) {
    if (ofr.is<Attribute>()) {
      // FIXME: We shouldn't need to do that, but right now, the static indices
      // are created with the wrong type: `i64` instead of `index`.
      // As a result, if we were to keep the attribute as is, we may fail to see
      // that two attributes are equal because one would have the i64 type and
      // the other the index type.
      // The alternative would be to create constant indices with getI64Attr in
      // this and the previous loop, but it doesn't logically make sense (we are
      // dealing with indices here) and would only strengthen the inconsistency
      // around how static indices are created (some places use getI64Attr,
      // others use getIndexAttr).
      // The workaround here is to stick to the IndexAttr type for all the
      // values, hence we recreate the attribute even when it is already static
      // to make sure the type is consistent.
      ofr = builder.getIndexAttr(
          llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
      continue;
    }
    std::optional<int64_t> maybeConstant =
        getConstantIntValue(ofr.get<Value>());
    if (maybeConstant)
      ofr = builder.getIndexAttr(*maybeConstant);
  }
}

/// Wrapper around `getShape` that conforms to the function signature
/// expected for `getAttributes` in `constifyIndexValues`.
static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
  ArrayRef<int64_t> sizes = memRefTy.getShape();
  return SmallVector<int64_t>(sizes.begin(), sizes.end());
}

/// Wrapper around `getStridesAndOffset` that returns only the offset and
/// conforms to the function signature expected for `getAttributes` in
/// `constifyIndexValues`.
static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
  SmallVector<int64_t> strides;
  int64_t offset;
  LogicalResult hasStaticInformation =
      getStridesAndOffset(memrefType, strides, offset);
  if (failed(hasStaticInformation))
    return SmallVector<int64_t>();
  return SmallVector<int64_t>(1, offset);
}

/// Wrapper around `getStridesAndOffset` that returns only the strides and
/// conforms to the function signature expected for `getAttributes` in
/// `constifyIndexValues`.
static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
  SmallVector<int64_t> strides;
  int64_t offset;
  LogicalResult hasStaticInformation =
      getStridesAndOffset(memrefType, strides, offset);
  if (failed(hasStaticInformation))
    return SmallVector<int64_t>();
  return strides;
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();

  return success();
}

LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }

LogicalResult AllocaOp::verify() {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(*this);
}

namespace {
/// Fold constant dimensions into an alloc like operation.
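/// As an illustration (hypothetical IR), an alloc whose dynamic size is fed
/// by a constant,
///   %c16 = arith.constant 16 : index
///   %0 = memref.alloc(%c16) : memref<?xf32>
/// is rewritten into a static alloc plus a cast back to the original type:
///   %1 = memref.alloc() : memref<16xf32>
///   %0 = memref.cast %1 : memref<16xf32> to memref<?xf32>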
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old memref.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
        alloc.getAlignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
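/// For instance (hypothetical IR), an allocation that is only stored to and
/// then deallocated,
///   %0 = memref.alloc() : memref<4xf32>
///   memref.store %v, %0[%i] : memref<4xf32>
///   memref.dealloc %0 : memref<4xf32>
/// has no observable effect and is erased together with those uses.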
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// ReallocOp
//===----------------------------------------------------------------------===//

LogicalResult ReallocOp::verify() {
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  // The source memref should have identity layout (or none).
  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")
           << sourceType;

  // The result memref should have identity layout (or none).
  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")
           << resultType;

  // The source memref and the result memref should be in the same memory space.
  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // The source memref and the result memref should have the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // Verify that we have the dynamic dimension operand when it is needed.
  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
           << resultType;
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")
           << resultType;

  return success();
}

void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

void AllocaScopeOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getBodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs());
}

ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&getBodyRegion()));
}

/// Given an operation, return whether this op is guaranteed to
/// allocate an AutomaticAllocationScopeResource.
static bool isGuaranteedAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return false;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Given an operation, return whether this op itself could
/// allocate an AutomaticAllocationScopeResource. Note that
/// this will not check whether an operation contained within
/// the op can allocate.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
  // This op itself doesn't create a stack allocation,
  // the inner allocation should be handled separately.
  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
    return false;
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return true;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Return whether this op is the last non-terminating op
/// in a region. That is to say, it is in a one-block region
/// and is only followed by a terminator. This prevents
/// extending the lifetime of allocations.
static bool lastNonTerminatorInRegion(Operation *op) {
  return op->getNextNode() == op->getBlock()->getTerminator() &&
         op->getParentRegion()->getBlocks().size() == 1;
}

/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
/// or it contains no allocation.
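/// For example (hypothetical IR), a scope that performs no stack allocation,
///   %r = memref.alloca_scope -> (index) {
///     %c0 = arith.constant 0 : index
///     memref.alloca_scope.return %c0 : index
///   }
/// can be inlined into the parent block, replacing %r with %c0.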
struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {
    bool hasPotentialAlloca =
        op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
            if (alloc == op)
              return WalkResult::advance();
            if (isOpItselfPotentialAutomaticAllocation(alloc))
              return WalkResult::interrupt();
            if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
              return WalkResult::skip();
            return WalkResult::advance();
          }).wasInterrupted();

    // If this contains no potential allocation, it is always legal to
    // inline. Otherwise, consider two conditions:
    if (hasPotentialAlloca) {
      // If the parent isn't an allocation scope, or we are not the last
      // non-terminator op in the parent, we will extend the lifetime.
      if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
        return failure();
      if (!lastNonTerminatorInRegion(op))
        return failure();
    }

    Block *block = &op.getRegion().front();
    Operation *terminator = block->getTerminator();
    ValueRange results = terminator->getOperands();
    rewriter.inlineBlockBefore(block, op);
    rewriter.replaceOp(op, results);
    rewriter.eraseOp(terminator);
    return success();
  }
};

/// Move allocations into an allocation scope, if it is legal to
/// move them (e.g. their operands are available at the location
/// the op would be moved to).
struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {

    if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    Operation *lastParentWithoutScope = op->getParentOp();

    if (!lastParentWithoutScope ||
        lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    // Only apply if this is the last non-terminator
    // op in the block (lest lifetime be extended) of a one
    // block region.
    if (!lastNonTerminatorInRegion(op) ||
        !lastNonTerminatorInRegion(lastParentWithoutScope))
      return failure();

    while (!lastParentWithoutScope->getParentOp()
                ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||
          !lastNonTerminatorInRegion(lastParentWithoutScope))
        return failure();
    }
    assert(lastParentWithoutScope->getParentOp()
               ->hasTrait<OpTrait::AutomaticAllocationScope>());

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

    SmallVector<Operation *> toHoist;
    op->walk([&](Operation *alloc) {
      if (!isGuaranteedAutomaticAllocation(alloc))
        return WalkResult::skip();

      // If any operand is not defined before the location of
      // lastParentWithoutScope (i.e. where we would hoist to), skip.
      if (llvm::any_of(alloc->getOperands(), [&](Value v) {
            return containingRegion->isAncestor(v.getParentRegion());
          }))
        return WalkResult::skip();
      toHoist.push_back(alloc);
      return WalkResult::advance();
    });

    if (toHoist.empty())
      return failure();
    rewriter.setInsertionPoint(lastParentWithoutScope);
    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
    return success();
  }
};

void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        !ShapedType::isDynamic(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
          aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
  }

  return false;
}

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {
/// If the source/target of a CopyOp is a CastOp that does not modify the shape
/// and element type, the cast can be skipped. Such CastOps only cast the layout
/// of the type.
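/// For example (hypothetical IR), in
///   %0 = memref.cast %src : memref<4xf32> to memref<4xf32, strided<[1]>>
///   memref.copy %0, %dst : memref<4xf32, strided<[1]>> to memref<4xf32>
/// the copy can read directly from %src, skipping the layout-only cast.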
struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;

    // Check source.
    if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.modifyOpInPlace(copyOp, [&] {
            copyOp.getSourceMutable().assign(castOp.getSource());
          });
          modified = true;
        }
      }
    }

    // Check target.
    if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.modifyOpInPlace(copyOp, [&] {
            copyOp.getTargetMutable().assign(castOp.getSource());
          });
          modified = true;
        }
      }
    }

    return success(modified);
  }
};

/// Fold memref.copy(%x, %x).
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())
      return failure();

    rewriter.eraseOp(copyOp);
    return success();
  }
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
}

LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  /// copy(memrefcast) -> copy
  bool folded = false;
  Operation *op = *this;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

/// Return a map with key being elements in `vals` and data being number of
/// occurrences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the `originalType` and a `candidateReducedType` whose shape is
/// assumed to be a subset of `originalType` with some `1` entries erased,
/// return the set of indices that specifies which of the entries of
/// `originalShape` are dropped to obtain `reducedShape`.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped the stride must be dropped
/// too.
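///
/// For example, with `originalType == memref<1x4x1xf32>` and a rank-2 reduced
/// type, both unit dimensions are candidates; the strides of the two types
/// decide whether the leading or the trailing unit dimension is the one
/// reported as dropped.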
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the number
  // of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
      failed(
          getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
    return failure();

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim that stride should not be
  // present in the candidate type. Note that there could be multiple dimensions
  // that have the same size. We don't need to exactly figure out which dim
  // corresponds to which stride, we just need to verify that the number of
  // repetitions of a stride in the original + number of unused dims with that
  // stride == number of repetitions of a stride in the candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this is not dropped. Keep as is.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced rank type
      // that wasn't in the original one.
      return failure();
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return failure();
  return unusedDims;
}

llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  return *unusedDims;
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
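/// For example (hypothetical IR),
///   %dst = memref.reshape %src(%shape)
///       : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %dst, %c1 : memref<?x?xf32>
/// becomes a load of the requested extent from the shape operand:
///   %d = memref.load %shape[%c1] : memref<2xindex>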
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    // dim of a memref reshape can be folded if dim.getIndex() dominates the
    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
    // cheaply check that either of the following conditions hold:
    //   1. dim.getIndex() is defined in the same block as reshape but before
    //      reshape.
    //   2. dim.getIndex() is defined in a parent block of reshape.

    // Check condition 1
    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
        }
      } // else dim.getIndex is a block argument to reshape->getBlock and
        // dominates reshape
    } // Check condition 2
    else if (dim->getBlock() != reshape->getBlock() &&
             !dim.getIndex().getParentRegion()->isProperAncestor(
                 reshape->getParentRegion())) {
      // If dim and reshape are in the same block but dim.getIndex() isn't, we
      // already know dim.getIndex() dominates reshape without calling
      // `isProperAncestor`
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");
    }

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  OpAsmParser::UnresolvedOperand tagMemrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size (number of elements).
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

//===----------------------------------------------------------------------===//
// ExtractStridedMetadataOp
//===----------------------------------------------------------------------===//

/// The number and type of the results are inferred from the
/// shape of the source.
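/// For a 2-D source, for instance (hypothetical IR), this expands to one
/// rank-0 base memref plus five index results:
///   %base, %offset, %sizes:2, %strides:2 =
///       memref.extract_strided_metadata %arg : memref<?x?xf32>
///       -> memref<f32>, index, index, index, index, index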
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType =
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  // Base.
  inferredReturnTypes.push_back(memrefType);
  // Offset.
  inferredReturnTypes.push_back(indexType);
  // Sizes and strides.
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}

void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  // For multi-result to work properly with pretty names and packed syntax `x:3`
  // we can only give a pretty name to the first value in the pack.
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}

/// Helper function to perform the replacement of all constant uses of `values`
/// by a materialized constant extracted from `maybeConstants`.
/// `values` and `maybeConstants` are expected to have the same size.
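/// For instance, if `values` holds the size results of an
/// extract_strided_metadata op and `maybeConstants` is `[8, %dyn]`, every use
/// of the first size result is rewritten to use a materialized
/// `arith.constant 8 : index`, while the second result is left untouched.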
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         " expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    // Don't materialize a constant if there are no uses: this would induce
    // infinite loops in the driver.
    if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
      continue;
    assert(maybeConstant.template is<Attribute>() &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");
    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
                 .getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      // modifyOpInPlace: lambda cannot capture structured bindings in C++17
      // yet.
      op->replaceUsesOfWith(result, constantVal);
      atLeastOneReplacement = true;
    }
  }
  return atLeastOneReplacement;
}

LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  return success(atLeastOneReplacement);
}

SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
  constifyIndexValues(values, getSource().getType(), getContext(),
                      getConstantSizes, ShapedType::isDynamic);
  return values;
}

SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
  constifyIndexValues(values, getSource().getType(), getContext(),
                      getConstantStrides, ShapedType::isDynamic);
  return values;
}

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
  SmallVector<OpFoldResult> values(1, offsetOfr);
  constifyIndexValues(values, getSource().getType(), getContext(),
                      getConstantOffset, ShapedType::isDynamic);
  return values[0];
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  OpBuilder::InsertionGuard g(builder);
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    builder.createBlock(bodyRegion);
    bodyRegion->addArgument(elementType, memref.getLoc());
  }
}

LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (isMemoryEffectFree(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError(
                "body of 'memref.generic_atomic_rmw' should contain "
                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}

ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::UnresolvedOperand memref;
  Type memrefType;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.parseOperand(memref) ||
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(memrefType) ||
      parser.resolveOperand(memref, memrefType, result.operands) ||
      parser.resolveOperands(ivs, indexType, result.operands))
    return failure();

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, {}) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
  return success();
}
1509 
1511  p << ' ' << getMemref() << "[" << getIndices()
1512  << "] : " << getMemref().getType() << ' ';
1513  p.printRegion(getRegion());
1514  p.printOptionalAttrDict((*this)->getAttrs());
1515 }
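As a reference for the custom printer above and the parser below, here is a small, illustrative example of the textual form of `memref.generic_atomic_rmw`; the SSA names and the body computation are placeholders, not taken from this file:

```mlir
%x = memref.generic_atomic_rmw %A[%i] : memref<10xf32> {
^bb0(%current_value : f32):
  %cst = arith.constant 1.0 : f32
  %new_value = arith.addf %current_value, %cst : f32
  memref.atomic_yield %new_value : f32
}
```

The region has exactly one block argument of the memref's element type, and the yielded value must match the op's result type, which is what the verifiers in this section check.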
1516 
1517 //===----------------------------------------------------------------------===//
1518 // AtomicYieldOp
1519 //===----------------------------------------------------------------------===//
1520 
1521 LogicalResult AtomicYieldOp::verify() {
1522  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1523  Type resultType = getResult().getType();
1524  if (parentType != resultType)
1525  return emitOpError() << "types mismatch between yield op: " << resultType
1526  << " and its parent: " << parentType;
1527  return success();
1528 }
1529 
1530 //===----------------------------------------------------------------------===//
1531 // GlobalOp
1532 //===----------------------------------------------------------------------===//
1533 
1534 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1535  TypeAttr type,
1536  Attribute initialValue) {
1537  p << type;
1538  if (!op.isExternal()) {
1539  p << " = ";
1540  if (op.isUninitialized())
1541  p << "uninitialized";
1542  else
1543  p.printAttributeWithoutType(initialValue);
1544  }
1545 }
1546 
1547 static ParseResult
1548 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1549  Attribute &initialValue) {
1550  Type type;
1551  if (parser.parseType(type))
1552  return failure();
1553 
1554  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1555  if (!memrefType || !memrefType.hasStaticShape())
1556  return parser.emitError(parser.getNameLoc())
1557  << "type should be static shaped memref, but got " << type;
1558  typeAttr = TypeAttr::get(type);
1559 
1560  if (parser.parseOptionalEqual())
1561  return success();
1562 
1563  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1564  initialValue = UnitAttr::get(parser.getContext());
1565  return success();
1566  }
1567 
1568  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1569  if (parser.parseAttribute(initialValue, tensorType))
1570  return failure();
1571  if (!llvm::isa<ElementsAttr>(initialValue))
1572  return parser.emitError(parser.getNameLoc())
1573  << "initial value should be a unit or elements attribute";
1574  return success();
1575 }
1576 
1577 LogicalResult GlobalOp::verify() {
1578  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1579  if (!memrefType || !memrefType.hasStaticShape())
1580  return emitOpError("type should be static shaped memref, but got ")
1581  << getType();
1582 
1583  // Verify that the initial value, if present, is either a unit attribute or
1584  // an elements attribute.
1585  if (getInitialValue().has_value()) {
1586  Attribute initValue = getInitialValue().value();
1587  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1588  return emitOpError("initial value should be a unit or elements "
1589  "attribute, but got ")
1590  << initValue;
1591 
1592  // Check that the type of the initial value is compatible with the type of
1593  // the global variable.
1594  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1595  Type initType = elementsAttr.getType();
1596  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1597  if (initType != tensorType)
1598  return emitOpError("initial value expected to be of type ")
1599  << tensorType << ", but was of type " << initType;
1600  }
1601  }
1602 
1603  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1604  uint64_t alignment = *alignAttr;
1605 
1606  if (!llvm::isPowerOf2_64(alignment))
1607  return emitError() << "alignment attribute value " << alignment
1608  << " is not a power of 2";
1609  }
1610 
1611  // TODO: verify visibility for declarations.
1612  return success();
1613 }
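For reference, a few illustrative `memref.global` declarations that the parser, printer, and verifier above accept (symbol names and values are placeholders):

```mlir
// Constant global with an elements-attribute initial value.
memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>

// Mutable global whose contents are left uninitialized.
memref.global @workspace : memref<4xi32> = uninitialized

// External declaration: no initial value at all.
memref.global "private" @ext : memref<8xf32>
```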
1614 
1615 ElementsAttr GlobalOp::getConstantInitValue() {
1616  auto initVal = getInitialValue();
1617  if (getConstant() && initVal.has_value())
1618  return llvm::cast<ElementsAttr>(initVal.value());
1619  return {};
1620 }
1621 
1622 //===----------------------------------------------------------------------===//
1623 // GetGlobalOp
1624 //===----------------------------------------------------------------------===//
1625 
1626 LogicalResult
1627 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1628  // Verify that the result type is same as the type of the referenced
1629  // memref.global op.
1630  auto global =
1631  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1632  if (!global)
1633  return emitOpError("'")
1634  << getName() << "' does not reference a valid global memref";
1635 
1636  Type resultType = getResult().getType();
1637  if (global.getType() != resultType)
1638  return emitOpError("result type ")
1639  << resultType << " does not match type " << global.getType()
1640  << " of the global memref @" << getName();
1641  return success();
1642 }
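A matching use of a global via `memref.get_global`; the symbol verifier above checks that the result type equals the type of the referenced `memref.global`. A minimal sketch, assuming a global `@c : memref<2xf32>` is in scope:

```mlir
%0 = memref.get_global @c : memref<2xf32>
%c0 = arith.constant 0 : index
%v = memref.load %0[%c0] : memref<2xf32>
```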
1643 
1644 //===----------------------------------------------------------------------===//
1645 // LoadOp
1646 //===----------------------------------------------------------------------===//
1647 
1648 LogicalResult LoadOp::verify() {
1649  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1650  return emitOpError("incorrect number of indices for load, expected ")
1651  << getMemRefType().getRank() << " but got " << getIndices().size();
1652  }
1653  return success();
1654 }
1655 
1656 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1657  /// load(memrefcast) -> load
1658  if (succeeded(foldMemRefCast(*this)))
1659  return getResult();
1660  return OpFoldResult();
1661 }
1662 
1663 //===----------------------------------------------------------------------===//
1664 // MemorySpaceCastOp
1665 //===----------------------------------------------------------------------===//
1666 
1667 void MemorySpaceCastOp::getAsmResultNames(
1668  function_ref<void(Value, StringRef)> setNameFn) {
1669  setNameFn(getResult(), "memspacecast");
1670 }
1671 
1672 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1673  if (inputs.size() != 1 || outputs.size() != 1)
1674  return false;
1675  Type a = inputs.front(), b = outputs.front();
1676  auto aT = llvm::dyn_cast<MemRefType>(a);
1677  auto bT = llvm::dyn_cast<MemRefType>(b);
1678 
1679  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1680  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1681 
1682  if (aT && bT) {
1683  if (aT.getElementType() != bT.getElementType())
1684  return false;
1685  if (aT.getLayout() != bT.getLayout())
1686  return false;
1687  if (aT.getShape() != bT.getShape())
1688  return false;
1689  return true;
1690  }
1691  if (uaT && ubT) {
1692  return uaT.getElementType() == ubT.getElementType();
1693  }
1694  return false;
1695 }
1696 
1697 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1698  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1699  // t2)
1700  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1701  getSourceMutable().assign(parentCast.getSource());
1702  return getResult();
1703  }
1704  return Value{};
1705 }
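An illustrative before/after of the cast-of-cast folder above; the memory-space numbers are arbitrary placeholders:

```mlir
// Before folding: a chain of two casts.
%a = memref.memory_space_cast %m : memref<4xf32, 1> to memref<4xf32, 2>
%b = memref.memory_space_cast %a : memref<4xf32, 2> to memref<4xf32, 3>

// After folding, the second cast reads directly from %m:
%b2 = memref.memory_space_cast %m : memref<4xf32, 1> to memref<4xf32, 3>
```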
1706 
1707 //===----------------------------------------------------------------------===//
1708 // PrefetchOp
1709 //===----------------------------------------------------------------------===//
1710 
1712  p << " " << getMemref() << '[';
1714  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1715  p << ", locality<" << getLocalityHint();
1716  p << ">, " << (getIsDataCache() ? "data" : "instr");
1717  p.printOptionalAttrDict(
1718  (*this)->getAttrs(),
1719  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1720  p << " : " << getMemRefType();
1721 }
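The printer above emits, and the parser below re-reads, a form like the following illustrative line (operands and the hint value are placeholders):

```mlir
memref.prefetch %A[%i, %j], read, locality<3>, data : memref<8x128xf32>
```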
1722 
1723 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1724  OpAsmParser::UnresolvedOperand memrefInfo;
1725  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1726  IntegerAttr localityHint;
1727  MemRefType type;
1728  StringRef readOrWrite, cacheType;
1729 
1730  auto indexTy = parser.getBuilder().getIndexType();
1731  auto i32Type = parser.getBuilder().getIntegerType(32);
1732  if (parser.parseOperand(memrefInfo) ||
1733  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1734  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1735  parser.parseComma() || parser.parseKeyword("locality") ||
1736  parser.parseLess() ||
1737  parser.parseAttribute(localityHint, i32Type, "localityHint",
1738  result.attributes) ||
1739  parser.parseGreater() || parser.parseComma() ||
1740  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1741  parser.resolveOperand(memrefInfo, type, result.operands) ||
1742  parser.resolveOperands(indexInfo, indexTy, result.operands))
1743  return failure();
1744 
1745  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1746  return parser.emitError(parser.getNameLoc(),
1747  "rw specifier has to be 'read' or 'write'");
1748  result.addAttribute(
1749  PrefetchOp::getIsWriteAttrStrName(),
1750  parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1751 
1752  if (!cacheType.equals("data") && !cacheType.equals("instr"))
1753  return parser.emitError(parser.getNameLoc(),
1754  "cache type has to be 'data' or 'instr'");
1755 
1756  result.addAttribute(
1757  PrefetchOp::getIsDataCacheAttrStrName(),
1758  parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1759 
1760  return success();
1761 }
1762 
1763 LogicalResult PrefetchOp::verify() {
1764  if (getNumOperands() != 1 + getMemRefType().getRank())
1765  return emitOpError("too few indices");
1766 
1767  return success();
1768 }
1769 
1770 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1771  SmallVectorImpl<OpFoldResult> &results) {
1772  // prefetch(memrefcast) -> prefetch
1773  return foldMemRefCast(*this);
1774 }
1775 
1776 //===----------------------------------------------------------------------===//
1777 // RankOp
1778 //===----------------------------------------------------------------------===//
1779 
1780 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1781  // Constant fold rank when the rank of the operand is known.
1782  auto type = getOperand().getType();
1783  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1784  if (shapedType && shapedType.hasRank())
1785  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1786  return IntegerAttr();
1787 }
1788 
1789 //===----------------------------------------------------------------------===//
1790 // ReinterpretCastOp
1791 //===----------------------------------------------------------------------===//
1792 
1793 void ReinterpretCastOp::getAsmResultNames(
1794  function_ref<void(Value, StringRef)> setNameFn) {
1795  setNameFn(getResult(), "reinterpret_cast");
1796 }
1797 
1798 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1799 /// `staticSizes` and `staticStrides` are automatically filled with
1800 /// source-memref-rank sentinel values that encode dynamic entries.
1801 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1802  MemRefType resultType, Value source,
1803  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1804  ArrayRef<OpFoldResult> strides,
1805  ArrayRef<NamedAttribute> attrs) {
1806  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1807  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1808  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1809  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1810  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1811  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1812  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1813  b.getDenseI64ArrayAttr(staticSizes),
1814  b.getDenseI64ArrayAttr(staticStrides));
1815  result.addAttributes(attrs);
1816 }
1817 
1818 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1819  MemRefType resultType, Value source,
1820  int64_t offset, ArrayRef<int64_t> sizes,
1821  ArrayRef<int64_t> strides,
1822  ArrayRef<NamedAttribute> attrs) {
1823  SmallVector<OpFoldResult> sizeValues =
1824  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1825  return b.getI64IntegerAttr(v);
1826  }));
1827  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1828  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1829  return b.getI64IntegerAttr(v);
1830  }));
1831  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1832  strideValues, attrs);
1833 }
1834 
1835 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1836  MemRefType resultType, Value source, Value offset,
1837  ValueRange sizes, ValueRange strides,
1838  ArrayRef<NamedAttribute> attrs) {
1839  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1840  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1841  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1842  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1843  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1844 }
1845 
1846 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1847 // completed automatically, like we have for subview and extract_slice.
1848 LogicalResult ReinterpretCastOp::verify() {
1849  // The source and result memrefs should be in the same memory space.
1850  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1851  auto resultType = llvm::cast<MemRefType>(getType());
1852  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1853  return emitError("different memory spaces specified for source type ")
1854  << srcType << " and result memref type " << resultType;
1855  if (srcType.getElementType() != resultType.getElementType())
1856  return emitError("different element types specified for source type ")
1857  << srcType << " and result memref type " << resultType;
1858 
1859  // Match sizes in result memref type and in static_sizes attribute.
1860  for (auto [idx, resultSize, expectedSize] :
1861  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1862  if (!ShapedType::isDynamic(resultSize) &&
1863  !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1864  return emitError("expected result type with size = ")
1865  << expectedSize << " instead of " << resultSize
1866  << " in dim = " << idx;
1867  }
1868 
1869  // Match offset and strides in static_offset and static_strides attributes. If
1870  // result memref type has no affine map specified, this will assume an
1871  // identity layout.
1872  int64_t resultOffset;
1873  SmallVector<int64_t, 4> resultStrides;
1874  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1875  return emitError("expected result type to have strided layout but found ")
1876  << resultType;
1877 
1878  // Match offset in result memref type and in static_offsets attribute.
1879  int64_t expectedOffset = getStaticOffsets().front();
1880  if (!ShapedType::isDynamic(resultOffset) &&
1881  !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
1882  return emitError("expected result type with offset = ")
1883  << expectedOffset << " instead of " << resultOffset;
1884 
1885  // Match strides in result memref type and in static_strides attribute.
1886  for (auto [idx, resultStride, expectedStride] :
1887  llvm::enumerate(resultStrides, getStaticStrides())) {
1888  if (!ShapedType::isDynamic(resultStride) &&
1889  !ShapedType::isDynamic(expectedStride) &&
1890  resultStride != expectedStride)
1891  return emitError("expected result type with stride = ")
1892  << expectedStride << " instead of " << resultStride
1893  << " in dim = " << idx;
1894  }
1895 
1896  return success();
1897 }
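A small example that satisfies the checks above: every static size, stride, and the offset match the strided layout of the result type (the values are illustrative):

```mlir
%dst = memref.reinterpret_cast %src to
         offset: [2], sizes: [10, 10], strides: [10, 1]
       : memref<?x?xf32> to memref<10x10xf32, strided<[10, 1], offset: 2>>
```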
1898 
1899 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1900  Value src = getSource();
1901  auto getPrevSrc = [&]() -> Value {
1902  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1903  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1904  return prev.getSource();
1905 
1906  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1907  if (auto prev = src.getDefiningOp<CastOp>())
1908  return prev.getSource();
1909 
1910  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1911  // are 0.
1912  if (auto prev = src.getDefiningOp<SubViewOp>())
1913  if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1914  return isConstantIntValue(val, 0);
1915  }))
1916  return prev.getSource();
1917 
1918  return nullptr;
1919  };
1920 
1921  if (auto prevSrc = getPrevSrc()) {
1922  getSourceMutable().assign(prevSrc);
1923  return getResult();
1924  }
1925 
1926  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1927  if (!ShapedType::isDynamicShape(getType().getShape()) &&
1928  src.getType() == getType() && getStaticOffsets().front() == 0) {
1929  return src;
1930  }
1931 
1932  return nullptr;
1933 }
1934 
1935 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1936  SmallVector<OpFoldResult> values = getMixedSizes();
1937  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
1938  ShapedType::isDynamic);
1939  return values;
1940 }
1941 
1942 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1943  SmallVector<OpFoldResult> values = getMixedStrides();
1944  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
1945  ShapedType::isDynamic);
1946  return values;
1947 }
1948 
1949 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1950  SmallVector<OpFoldResult> values = getMixedOffsets();
1951  assert(values.size() == 1 &&
1952  "reinterpret_cast must have one and only one offset");
1953  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
1954  ShapedType::isDynamic);
1955  return values[0];
1956 }
1957 
1958 namespace {
1959 /// Replace the sequence:
1960 /// ```
1961 /// base, offset, sizes, strides = extract_strided_metadata src
1962 /// dst = reinterpret_cast base to offset, sizes, strides
1963 /// ```
1964 /// With
1965 ///
1966 /// ```
1967 /// dst = memref.cast src
1968 /// ```
1969 ///
1970 /// Note: The cast operation is only inserted when the types of dst and src
1971 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
1972 ///
1973 /// This pattern also matches when the offset, sizes, and strides don't come
1974 /// directly from the `extract_strided_metadata`'s results but it can be
1975 /// statically proven that they would hold the same values.
1976 ///
1977 /// For instance, the following sequence would be replaced:
1978 /// ```
1979 /// base, offset, sizes, strides =
1980 /// extract_strided_metadata memref : memref<3x4xty>
1981 /// dst = reinterpret_cast base to 0, [3, 4], strides
1982 /// ```
1983 /// Because we know (thanks to the type of the input memref) that variable
1984 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
1985 ///
1986 /// Similarly, the following sequence would be replaced:
1987 /// ```
1988 /// c0 = arith.constant 0
1989 /// c4 = arith.constant 4
1990 /// base, offset, sizes, strides =
1991 /// extract_strided_metadata memref : memref<3x4xty>
1992 /// dst = reinterpret_cast base to c0, [3, c4], strides
1993 /// ```
1994 /// Because we know that `offset` and `c0` will hold 0
1995 /// and `c4` will hold 4.
1996 struct ReinterpretCastOpExtractStridedMetadataFolder
1997  : public OpRewritePattern<ReinterpretCastOp> {
1998 public:
1999  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2000 
2001  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2002  PatternRewriter &rewriter) const override {
2003  auto extractStridedMetadata =
2004  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2005  if (!extractStridedMetadata)
2006  return failure();
2007  // Check if the reinterpret cast reconstructs a memref with the exact same
2008  // properties as the extract strided metadata.
2009 
2010  // First, check that the strides are the same.
2011  SmallVector<OpFoldResult> extractStridesOfr =
2012  extractStridedMetadata.getConstifiedMixedStrides();
2013  SmallVector<OpFoldResult> reinterpretStridesOfr =
2014  op.getConstifiedMixedStrides();
2015  if (extractStridesOfr.size() != reinterpretStridesOfr.size())
2016  return failure();
2017 
2018  unsigned rank = op.getType().getRank();
2019  for (unsigned i = 0; i < rank; ++i) {
2020  if (extractStridesOfr[i] != reinterpretStridesOfr[i])
2021  return failure();
2022  }
2023 
2024  // Second, check the sizes.
2025  assert(extractStridedMetadata.getSizes().size() ==
2026  op.getMixedSizes().size() &&
2027  "Strides and sizes rank must match");
2028  SmallVector<OpFoldResult> extractSizesOfr =
2029  extractStridedMetadata.getConstifiedMixedSizes();
2030  SmallVector<OpFoldResult> reinterpretSizesOfr =
2031  op.getConstifiedMixedSizes();
2032  for (unsigned i = 0; i < rank; ++i) {
2033  if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2034  return failure();
2035  }
2036  // Finally, check the offset.
2037  assert(op.getMixedOffsets().size() == 1 &&
2038  "reinterpret_cast with more than one offset should have been "
2039  "rejected by the verifier");
2040  OpFoldResult extractOffsetOfr =
2041  extractStridedMetadata.getConstifiedMixedOffset();
2042  OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2043  if (extractOffsetOfr != reinterpretOffsetOfr)
2044  return failure();
2045 
2046  // At this point, we know that the back and forth between extract strided
2047  // metadata and reinterpret cast is a noop. However, the final type of the
2048  // reinterpret cast may not be exactly the same as the original memref.
2049  // E.g., it could be changing a dimension from static to dynamic. Check that
2050  // here and add a cast if necessary.
2051  Type srcTy = extractStridedMetadata.getSource().getType();
2052  if (srcTy == op.getResult().getType())
2053  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2054  else
2055  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2056  extractStridedMetadata.getSource());
2057 
2058  return success();
2059  }
2060 };
2061 } // namespace
2062 
2063 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2064  MLIRContext *context) {
2065  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2066 }
2067 
2068 //===----------------------------------------------------------------------===//
2069 // Reassociative reshape ops
2070 //===----------------------------------------------------------------------===//
2071 
2072 void CollapseShapeOp::getAsmResultNames(
2073  function_ref<void(Value, StringRef)> setNameFn) {
2074  setNameFn(getResult(), "collapse_shape");
2075 }
2076 
2077 void ExpandShapeOp::getAsmResultNames(
2078  function_ref<void(Value, StringRef)> setNameFn) {
2079  setNameFn(getResult(), "expand_shape");
2080 }
2081 
2082 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2083 /// result and operand. Layout maps are verified separately.
2084 ///
2085 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2086 /// allowed in a reassociation group.
2087 static LogicalResult
2088 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2089  ArrayRef<int64_t> expandedShape,
2090  ArrayRef<ReassociationIndices> reassociation,
2091  bool allowMultipleDynamicDimsPerGroup) {
2092  // There must be one reassociation group per collapsed dimension.
2093  if (collapsedShape.size() != reassociation.size())
2094  return op->emitOpError("invalid number of reassociation groups: found ")
2095  << reassociation.size() << ", expected " << collapsedShape.size();
2096 
2097  // The next expected expanded dimension index (while iterating over
2098  // reassociation indices).
2099  int64_t nextDim = 0;
2100  for (const auto &it : llvm::enumerate(reassociation)) {
2101  ReassociationIndices group = it.value();
2102  int64_t collapsedDim = it.index();
2103 
2104  bool foundDynamic = false;
2105  for (int64_t expandedDim : group) {
2106  if (expandedDim != nextDim++)
2107  return op->emitOpError("reassociation indices must be contiguous");
2108 
2109  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2110  return op->emitOpError("reassociation index ")
2111  << expandedDim << " is out of bounds";
2112 
2113  // Check if there are multiple dynamic dims in a reassociation group.
2114  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2115  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2116  return op->emitOpError(
2117  "at most one dimension in a reassociation group may be dynamic");
2118  foundDynamic = true;
2119  }
2120  }
2121 
2122  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2123  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2124  return op->emitOpError("collapsed dim (")
2125  << collapsedDim
2126  << ") must be dynamic if and only if reassociation group is "
2127  "dynamic";
2128 
2129  // If all dims in the reassociation group are static, the size of the
2130  // collapsed dim can be verified.
2131  if (!foundDynamic) {
2132  int64_t groupSize = 1;
2133  for (int64_t expandedDim : group)
2134  groupSize *= expandedShape[expandedDim];
2135  if (groupSize != collapsedShape[collapsedDim])
2136  return op->emitOpError("collapsed dim size (")
2137  << collapsedShape[collapsedDim]
2138  << ") must equal reassociation group size (" << groupSize << ")";
2139  }
2140  }
2141 
2142  if (collapsedShape.empty()) {
2143  // Rank 0: All expanded dimensions must be 1.
2144  for (int64_t d : expandedShape)
2145  if (d != 1)
2146  return op->emitOpError(
2147  "rank 0 memrefs can only be extended/collapsed with/from ones");
2148  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2149  // Rank >= 1: Number of dimensions among all reassociation groups must match
2150  // the result memref rank.
2151  return op->emitOpError("expanded rank (")
2152  << expandedShape.size()
2153  << ") inconsistent with number of reassociation indices (" << nextDim
2154  << ")";
2155  }
2156 
2157  return success();
2158 }
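As an illustration of the grouping rules above, collapsing a static 3x4x5 memref with reassociation [[0, 1], [2]] is accepted because each group's product matches the corresponding collapsed size (3*4 = 12 and 5 = 5):

```mlir
%c = memref.collapse_shape %m [[0, 1], [2]]
    : memref<3x4x5xf32> into memref<12x5xf32>
```

A collapsed dim may be dynamic only if its reassociation group contains a dynamic expanded dim (and, when multiple dynamic dims per group are disallowed, at most one).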
2159 
2160 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2161  return getSymbolLessAffineMaps(getReassociationExprs());
2162 }
2163 
2164 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2165  return convertReassociationIndicesToExprs(getContext(),
2166  getReassociationIndices());
2167 }
2168 
2169 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2170  return getSymbolLessAffineMaps(getReassociationExprs());
2171 }
2172 
2173 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2174  return convertReassociationIndicesToExprs(getContext(),
2175  getReassociationIndices());
2176 }
2177 
2178 /// Compute the layout map after expanding a given source MemRef type with the
2179 /// specified reassociation indices.
2180 static FailureOr<StridedLayoutAttr>
2181 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2182  ArrayRef<ReassociationIndices> reassociation) {
2183  int64_t srcOffset;
2184  SmallVector<int64_t> srcStrides;
2185  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2186  return failure();
2187  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2188 
2189  // 1-1 mapping between srcStrides and reassociation packs.
2190  // Each srcStride starts with the given value and gets expanded according to
2191  // the proper entries in resultShape.
2192  // Example:
2193  // srcStrides = [10000, 1 , 100 ],
2194  // reassociations = [ [0], [1], [2, 3, 4]],
2195  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2196  // -> For the purpose of stride calculation, the useful sizes are:
2197  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2198  // resultStrides = [10000, 1, 600, 200, 100]
2199  // Note that a stride does not get expanded along the first entry of each
2200  // shape pack.
2201  SmallVector<int64_t> reverseResultStrides;
2202  reverseResultStrides.reserve(resultShape.size());
2203  unsigned shapeIndex = resultShape.size() - 1;
2204  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2205  ReassociationIndices reassoc = std::get<0>(it);
2206  int64_t currentStrideToExpand = std::get<1>(it);
2207  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2208  reverseResultStrides.push_back(currentStrideToExpand);
2209  currentStrideToExpand =
2210  (SaturatedInteger::wrap(currentStrideToExpand) *
2211  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2212  .asInteger();
2213  }
2214  }
2215  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2216  resultStrides.resize(resultShape.size(), 1);
2217  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2218 }
2219 
2220 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2221  MemRefType srcType, ArrayRef<int64_t> resultShape,
2222  ArrayRef<ReassociationIndices> reassociation) {
2223  if (srcType.getLayout().isIdentity()) {
2224  // If the source is contiguous (i.e., no layout map specified), so is the
2225  // result.
2226  MemRefLayoutAttrInterface layout;
2227  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2228  srcType.getMemorySpace());
2229  }
2230 
2231  // Source may not be contiguous. Compute the layout map.
2232  FailureOr<StridedLayoutAttr> computedLayout =
2233  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2234  if (failed(computedLayout))
2235  return failure();
2236  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2237  srcType.getMemorySpace());
2238 }
2239 
2240 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2241  ArrayRef<int64_t> resultShape, Value src,
2242  ArrayRef<ReassociationIndices> reassociation) {
2243  // Only ranked memref source values are supported.
2244  auto srcType = llvm::cast<MemRefType>(src.getType());
2245  FailureOr<MemRefType> resultType =
2246  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2247  // Failure of this assertion usually indicates a problem with the source
2248  // type, e.g., could not get strides/offset.
2249  assert(succeeded(resultType) && "could not compute layout");
2250  build(builder, result, *resultType, src, reassociation);
2251 }
2252 
2253 LogicalResult ExpandShapeOp::verify() {
2254  MemRefType srcType = getSrcType();
2255  MemRefType resultType = getResultType();
2256 
2257  if (srcType.getRank() > resultType.getRank()) {
2258  auto r0 = srcType.getRank();
2259  auto r1 = resultType.getRank();
2260  return emitOpError("has source rank ")
2261  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2262  << r0 << " > " << r1 << ").";
2263  }
2264 
2265  // Verify result shape.
2266  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2267  resultType.getShape(),
2268  getReassociationIndices(),
2269  /*allowMultipleDynamicDimsPerGroup=*/false)))
2270  return failure();
2271 
2272  // Compute expected result type (including layout map).
2273  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2274  srcType, resultType.getShape(), getReassociationIndices());
2275  if (failed(expectedResultType))
2276  return emitOpError("invalid source layout map");
2277 
2278  // Check actual result type.
2279  if (*expectedResultType != resultType)
2280  return emitOpError("expected expanded type to be ")
2281  << *expectedResultType << " but found " << resultType;
2282 
2283  return success();
2284 }
2285 
2286 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2287  MLIRContext *context) {
2290  context);
2291 }
2292 
2293 /// Compute the layout map after collapsing a given source MemRef type with the
2294 /// specified reassociation indices.
2295 ///
2296 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2297 /// not possible to check this by inspecting a MemRefType in the general case.
2298 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2299 /// be valid (and thus accepted by this function) unless `strict = true`.
2300 static FailureOr<StridedLayoutAttr>
2301 computeCollapsedLayoutMap(MemRefType srcType,
2302  ArrayRef<ReassociationIndices> reassociation,
2303  bool strict = false) {
2304  int64_t srcOffset;
2305  SmallVector<int64_t> srcStrides;
2306  auto srcShape = srcType.getShape();
2307  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2308  return failure();
2309 
2310  // The result stride of a reassociation group is the stride of the last entry
2311  // of the reassociation. (TODO: Should be the minimum stride in the
2312  // reassociation because strides are not necessarily sorted. E.g., when using
2313  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2314  // strides are meaningless and could have any arbitrary value.
2315  SmallVector<int64_t> resultStrides;
2316  resultStrides.reserve(reassociation.size());
2317  for (const ReassociationIndices &reassoc : reassociation) {
2318  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2319  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2320  ref = ref.drop_back();
2321  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2322  resultStrides.push_back(srcStrides[ref.back()]);
2323  } else {
2324  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2325  // the corresponding stride may have to be skipped. (See above comment.)
2326  // Therefore, the result stride cannot be statically determined and must
2327  // be dynamic.
2328  resultStrides.push_back(ShapedType::kDynamic);
2329  }
2330  }
2331 
2332  // Validate that each reassociation group is contiguous.
2333  unsigned resultStrideIndex = resultStrides.size() - 1;
2334  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2335  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2336  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2337  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2338  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2339 
2340  // Both source and result stride must have the same static value. In that
2341  // case, we can be sure, that the dimensions are collapsible (because they
2342  // are contiguous).
2343  // If `strict = false` (default during op verification), we accept cases
2344  // where one or both strides are dynamic. This is best effort: We reject
2345  // ops where obviously non-contiguous dims are collapsed, but accept ops
2346  // where we cannot be sure statically. Such ops may fail at runtime. See
2347  // the op documentation for details.
2348  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2349  if (strict && (stride.saturated || srcStride.saturated))
2350  return failure();
2351 
2352  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2353  return failure();
2354  }
2355  }
2356  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2357 }
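A worked instance of the contiguity check above on a non-identity layout (for example, a memref produced by a subview). The trailing group [1, 2] is contiguous because stride(1) = 5 equals stride(2) * size(2) = 1 * 5, so it collapses to a single dim of stride 1:

```mlir
%c = memref.collapse_shape %m [[0], [1, 2]]
    : memref<3x4x5xf32, strided<[40, 5, 1], offset: 2>>
      into memref<3x20xf32, strided<[40, 1], offset: 2>>
```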
2358 
2359 bool CollapseShapeOp::isGuaranteedCollapsible(
2360  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2361  // MemRefs with identity layout are always collapsible.
2362  if (srcType.getLayout().isIdentity())
2363  return true;
2364 
2365  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2366  /*strict=*/true));
2367 }
2368 
2369 MemRefType CollapseShapeOp::computeCollapsedType(
2370  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2371  SmallVector<int64_t> resultShape;
2372  resultShape.reserve(reassociation.size());
2373  for (const ReassociationIndices &group : reassociation) {
2374  auto groupSize = SaturatedInteger::wrap(1);
2375  for (int64_t srcDim : group)
2376  groupSize =
2377  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2378  resultShape.push_back(groupSize.asInteger());
2379  }
2380 
2381  if (srcType.getLayout().isIdentity()) {
2382  // If the source is contiguous (i.e., no layout map specified), so is the
2383  // result.
2384  MemRefLayoutAttrInterface layout;
2385  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2386  srcType.getMemorySpace());
2387  }
2388 
2389  // Source may not be fully contiguous. Compute the layout map.
2390  // Note: Dimensions that are collapsed into a single dim are assumed to be
2391  // contiguous.
2392  FailureOr<StridedLayoutAttr> computedLayout =
2393  computeCollapsedLayoutMap(srcType, reassociation);
2394  assert(succeeded(computedLayout) &&
2395  "invalid source layout map or collapsing non-contiguous dims");
2396  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2397  srcType.getMemorySpace());
2398 }
2399 
2400 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2401  ArrayRef<ReassociationIndices> reassociation,
2402  ArrayRef<NamedAttribute> attrs) {
2403  auto srcType = llvm::cast<MemRefType>(src.getType());
2404  MemRefType resultType =
2405  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2406  build(b, result, resultType, src, attrs);
2407  result.addAttribute(getReassociationAttrStrName(),
2408  getReassociationIndicesAttribute(b, reassociation));
2409 }
2410 
2411 LogicalResult CollapseShapeOp::verify() {
2412  MemRefType srcType = getSrcType();
2413  MemRefType resultType = getResultType();
2414 
2415  if (srcType.getRank() < resultType.getRank()) {
2416  auto r0 = srcType.getRank();
2417  auto r1 = resultType.getRank();
2418  return emitOpError("has source rank ")
2419  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2420  << r0 << " < " << r1 << ").";
2421  }
2422 
2423  // Verify result shape.
2424  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2425  srcType.getShape(), getReassociationIndices(),
2426  /*allowMultipleDynamicDimsPerGroup=*/true)))
2427  return failure();
2428 
2429  // Compute expected result type (including layout map).
2430  MemRefType expectedResultType;
2431  if (srcType.getLayout().isIdentity()) {
2432  // If the source is contiguous (i.e., no layout map specified), so is the
2433  // result.
2434  MemRefLayoutAttrInterface layout;
2435  expectedResultType =
2436  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2437  srcType.getMemorySpace());
2438  } else {
2439  // Source may not be fully contiguous. Compute the layout map.
2440  // Note: Dimensions that are collapsed into a single dim are assumed to be
2441  // contiguous.
2442  FailureOr<StridedLayoutAttr> computedLayout =
2443  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2444  if (failed(computedLayout))
2445  return emitOpError(
2446  "invalid source layout map or collapsing non-contiguous dims");
2447  expectedResultType =
2448  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2449  *computedLayout, srcType.getMemorySpace());
2450  }
2451 
2452  if (expectedResultType != resultType)
2453  return emitOpError("expected collapsed type to be ")
2454  << expectedResultType << " but found " << resultType;
2455 
2456  return success();
2457 }
2458 
2459 struct CollapseShapeOpMemRefCastFolder
2460  : public OpRewritePattern<CollapseShapeOp> {
2461 public:
2462  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2463 
2464  LogicalResult matchAndRewrite(CollapseShapeOp op,
2465  PatternRewriter &rewriter) const override {
2466  auto cast = op.getOperand().getDefiningOp<CastOp>();
2467  if (!cast)
2468  return failure();
2469 
2470  if (!CastOp::canFoldIntoConsumerOp(cast))
2471  return failure();
2472 
2473  Type newResultType = CollapseShapeOp::computeCollapsedType(
2474  llvm::cast<MemRefType>(cast.getOperand().getType()),
2475  op.getReassociationIndices());
2476 
2477  if (newResultType == op.getResultType()) {
2478  rewriter.modifyOpInPlace(
2479  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2480  } else {
2481  Value newOp = rewriter.create<CollapseShapeOp>(
2482  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2483  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2484  }
2485  return success();
2486  }
2487 };
2488 
2489 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2490  MLIRContext *context) {
2494 }
2495 
2496 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2497  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2498  adaptor.getOperands());
2499 }
2500 
2501 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2502  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2503  adaptor.getOperands());
2504 }
2505 
2506 //===----------------------------------------------------------------------===//
2507 // ReshapeOp
2508 //===----------------------------------------------------------------------===//
2509 
2510 void ReshapeOp::getAsmResultNames(
2511  function_ref<void(Value, StringRef)> setNameFn) {
2512  setNameFn(getResult(), "reshape");
2513 }
2514 
2515 LogicalResult ReshapeOp::verify() {
2516  Type operandType = getSource().getType();
2517  Type resultType = getResult().getType();
2518 
2519  Type operandElementType =
2520  llvm::cast<ShapedType>(operandType).getElementType();
2521  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2522  if (operandElementType != resultElementType)
2523  return emitOpError("element types of source and destination memref "
2524  "types should be the same");
2525 
2526  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2527  if (!operandMemRefType.getLayout().isIdentity())
2528  return emitOpError("source memref type should have identity affine map");
2529 
2530  int64_t shapeSize =
2531  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2532  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2533  if (resultMemRefType) {
2534  if (!resultMemRefType.getLayout().isIdentity())
2535  return emitOpError("result memref type should have identity affine map");
2536  if (shapeSize == ShapedType::kDynamic)
2537  return emitOpError("cannot use shape operand with dynamic length to "
2538  "reshape to statically-ranked memref type");
2539  if (shapeSize != resultMemRefType.getRank())
2540  return emitOpError(
2541  "length of shape operand differs from the result's memref rank");
2542  }
2543  return success();
2544 }
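For reference, an illustrative `memref.reshape` that passes the checks above: identity layouts on both sides, matching element types, and a statically-sized shape operand whose length (2) equals the result rank (%src and %shape are placeholders):

```mlir
%dst = memref.reshape %src(%shape)
    : (memref<4x5xf32>, memref<2xindex>) -> memref<2x10xf32>
```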
2545 
2546 //===----------------------------------------------------------------------===//
2547 // StoreOp
2548 //===----------------------------------------------------------------------===//
2549 
2550 LogicalResult StoreOp::verify() {
2551  if (getNumOperands() != 2 + getMemRefType().getRank())
2552  return emitOpError("store index operand count not equal to memref rank");
2553 
2554  return success();
2555 }
2556 
2557 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2558  SmallVectorImpl<OpFoldResult> &results) {
2559  /// store(memrefcast) -> store
2560  return foldMemRefCast(*this, getValueToStore());
2561 }
2562 
2563 //===----------------------------------------------------------------------===//
2564 // SubViewOp
2565 //===----------------------------------------------------------------------===//
2566 
2567 void SubViewOp::getAsmResultNames(
2568  function_ref<void(Value, StringRef)> setNameFn) {
2569  setNameFn(getResult(), "subview");
2570 }
2571 
2572 /// A subview result type can be fully inferred from the source type and the
2573 /// static representation of offsets, sizes and strides. Special sentinels
2574 /// encode the dynamic case.
2575 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2576  ArrayRef<int64_t> staticOffsets,
2577  ArrayRef<int64_t> staticSizes,
2578  ArrayRef<int64_t> staticStrides) {
2579  unsigned rank = sourceMemRefType.getRank();
2580  (void)rank;
2581  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2582  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2583  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2584 
2585  // Extract source offset and strides.
2586  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);
2587 
2588  // Compute target offset whose value is:
2589  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2590  int64_t targetOffset = sourceOffset;
2591  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2592  auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
2593  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2594  SaturatedInteger::wrap(staticOffset) *
2595  SaturatedInteger::wrap(targetStride))
2596  .asInteger();
2597  }
2598 
2599  // Compute target stride whose value is:
2600  // `sourceStrides_i * staticStrides_i`.
2601  SmallVector<int64_t, 4> targetStrides;
2602  targetStrides.reserve(staticOffsets.size());
2603  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2604  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2605  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2606  SaturatedInteger::wrap(staticStride))
2607  .asInteger());
2608  }
2609 
2610  // The type is now known.
2611  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2612  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2613  targetOffset, targetStrides),
2614  sourceMemRefType.getMemorySpace());
2615 }
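A worked instance of the two formulas above, assuming an identity-layout source memref<8x16x4xf32> (strides [64, 4, 1], offset 0) with offsets [2, 4, 0], sizes [4, 4, 4], and strides [1, 2, 1]: the inferred offset is 2*64 + 4*4 + 0*1 = 144 and the inferred strides are [64*1, 4*2, 1*1] = [64, 8, 1]:

```mlir
%s = memref.subview %m[2, 4, 0] [4, 4, 4] [1, 2, 1]
    : memref<8x16x4xf32>
      to memref<4x4x4xf32, strided<[64, 8, 1], offset: 144>>
```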
2616 
2617 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2618  ArrayRef<OpFoldResult> offsets,
2619  ArrayRef<OpFoldResult> sizes,
2620  ArrayRef<OpFoldResult> strides) {
2621  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2622  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2623  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2624  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2625  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2626  if (!hasValidSizesOffsets(staticOffsets))
2627  return {};
2628  if (!hasValidSizesOffsets(staticSizes))
2629  return {};
2630  if (!hasValidStrides(staticStrides))
2631  return {};
2632  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2633  staticSizes, staticStrides);
2634 }
2635 
2636 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2637  MemRefType sourceRankedTensorType,
2638  ArrayRef<int64_t> offsets,
2639  ArrayRef<int64_t> sizes,
2640  ArrayRef<int64_t> strides) {
2641  auto inferredType = llvm::cast<MemRefType>(
2642  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2643  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2644  "expected ");
2645  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2646  return inferredType;
2647 
2648  // Compute which dimensions are dropped.
2649  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2650  computeRankReductionMask(inferredType.getShape(), resultShape);
2651  assert(dimsToProject.has_value() && "invalid rank reduction");
2652 
2653  // Compute the layout and result type.
2654  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2655  SmallVector<int64_t> rankReducedStrides;
2656  rankReducedStrides.reserve(resultShape.size());
2657  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2658  if (!dimsToProject->contains(idx))
2659  rankReducedStrides.push_back(value);
2660  }
2661  return MemRefType::get(resultShape, inferredType.getElementType(),
2662  StridedLayoutAttr::get(inferredLayout.getContext(),
2663  inferredLayout.getOffset(),
2664  rankReducedStrides),
2665  inferredType.getMemorySpace());
2666 }
2667 
2668 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2669  MemRefType sourceRankedTensorType,
2670  ArrayRef<OpFoldResult> offsets,
2671  ArrayRef<OpFoldResult> sizes,
2672  ArrayRef<OpFoldResult> strides) {
2673  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2674  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2675  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2676  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2677  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2678  return SubViewOp::inferRankReducedResultType(
2679  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2680  staticStrides);
2681 }
2682 
2683 // Build a SubViewOp with mixed static and dynamic entries and custom result
2684 // type. If the type passed is nullptr, it is inferred.
2685 void SubViewOp::build(OpBuilder &b, OperationState &result,
2686  MemRefType resultType, Value source,
2687  ArrayRef<OpFoldResult> offsets,
2688  ArrayRef<OpFoldResult> sizes,
2689  ArrayRef<OpFoldResult> strides,
2690  ArrayRef<NamedAttribute> attrs) {
2691  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2692  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2693  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2694  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2695  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2696  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2697  // Structuring implementation this way avoids duplication between builders.
2698  if (!resultType) {
2699  resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2700  sourceMemRefType, staticOffsets, staticSizes, staticStrides));
2701  }
2702  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2703  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2704  b.getDenseI64ArrayAttr(staticSizes),
2705  b.getDenseI64ArrayAttr(staticStrides));
2706  result.addAttributes(attrs);
2707 }
2708 
2709 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2710 // type.
2711 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2712  ArrayRef<OpFoldResult> offsets,
2713  ArrayRef<OpFoldResult> sizes,
2714  ArrayRef<OpFoldResult> strides,
2715  ArrayRef<NamedAttribute> attrs) {
2716  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2717 }
2718 
2719 // Build a SubViewOp with static entries and inferred result type.
2720 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2721  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2722  ArrayRef<int64_t> strides,
2723  ArrayRef<NamedAttribute> attrs) {
2724  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2725  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2726  return b.getI64IntegerAttr(v);
2727  }));
2728  SmallVector<OpFoldResult> sizeValues =
2729  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2730  return b.getI64IntegerAttr(v);
2731  }));
2732  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2733  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2734  return b.getI64IntegerAttr(v);
2735  }));
2736  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2737 }
2738 
2739 // Build a SubViewOp with static entries and custom result type. If the
2740 // type passed is nullptr, it is inferred.
2741 void SubViewOp::build(OpBuilder &b, OperationState &result,
2742  MemRefType resultType, Value source,
2743  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2744  ArrayRef<int64_t> strides,
2745  ArrayRef<NamedAttribute> attrs) {
2746  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2747  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2748  return b.getI64IntegerAttr(v);
2749  }));
2750  SmallVector<OpFoldResult> sizeValues =
2751  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2752  return b.getI64IntegerAttr(v);
2753  }));
2754  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2755  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2756  return b.getI64IntegerAttr(v);
2757  }));
2758  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2759  attrs);
2760 }
2761 
2762 // Build a SubViewOp with dynamic entries and custom result type. If the type
2763 // passed is nullptr, it is inferred.
2764 void SubViewOp::build(OpBuilder &b, OperationState &result,
2765  MemRefType resultType, Value source, ValueRange offsets,
2766  ValueRange sizes, ValueRange strides,
2767  ArrayRef<NamedAttribute> attrs) {
2768  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2769  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2770  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2771  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2772  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2773  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2774  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2775 }
2776 
2777 // Build a SubViewOp with dynamic entries and inferred result type.
2778 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2779  ValueRange offsets, ValueRange sizes, ValueRange strides,
2780  ArrayRef<NamedAttribute> attrs) {
2781  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2782 }
2783 
2784 /// For ViewLikeOpInterface.
2785 Value SubViewOp::getViewSource() { return getSource(); }
2786 
2787 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2788 /// static value).
2789 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2790  int64_t t1Offset, t2Offset;
2791  SmallVector<int64_t> t1Strides, t2Strides;
2792  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2793  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2794  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2795 }
2796 
2797 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2798 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2799 /// marked as dropped in `droppedDims`.
2800 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2801  const llvm::SmallBitVector &droppedDims) {
2802  assert(size_t(t1.getRank()) == droppedDims.size() && "incorrect number of bits");
2803  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2804  "incorrect number of dropped dims");
2805  int64_t t1Offset, t2Offset;
2806  SmallVector<int64_t> t1Strides, t2Strides;
2807  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2808  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2809  if (failed(res1) || failed(res2))
2810  return false;
2811  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2812  if (droppedDims[i])
2813  continue;
2814  if (t1Strides[i] != t2Strides[j])
2815  return false;
2816  ++j;
2817  }
2818  return true;
2819 }
2820 
2821 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2822  Operation *op, Type expectedType) {
2823  auto memrefType = llvm::cast<ShapedType>(expectedType);
2824  switch (result) {
2825  case SliceVerificationResult::Success:
2826  return success();
2827  case SliceVerificationResult::RankTooLarge:
2828  return op->emitError("expected result rank to be smaller or equal to ")
2829  << "the source rank. ";
2830  case SliceVerificationResult::SizeMismatch:
2831  return op->emitError("expected result type to be ")
2832  << expectedType
2833  << " or a rank-reduced version. (mismatch of result sizes) ";
2834  case SliceVerificationResult::ElemTypeMismatch:
2835  return op->emitError("expected result element type to be ")
2836  << memrefType.getElementType();
2837  case SliceVerificationResult::MemSpaceMismatch:
2838  return op->emitError("expected result and source memory spaces to match.");
2839  case SliceVerificationResult::LayoutMismatch:
2840  return op->emitError("expected result type to be ")
2841  << expectedType
2842  << " or a rank-reduced version. (mismatch of result layout) ";
2843  }
2844  llvm_unreachable("unexpected subview verification result");
2845 }
2846 
2847 /// Verifier for SubViewOp.
2848 LogicalResult SubViewOp::verify() {
2849  MemRefType baseType = getSourceType();
2850  MemRefType subViewType = getType();
2851 
2852  // The base memref and the view memref should be in the same memory space.
2853  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2854  return emitError("different memory spaces specified for base memref "
2855  "type ")
2856  << baseType << " and subview memref type " << subViewType;
2857 
2858  // Verify that the base memref type has a strided layout map.
2859  if (!isStrided(baseType))
2860  return emitError("base type ") << baseType << " is not strided";
2861 
2862  // Compute the expected result type, assuming that there are no rank
2863  // reductions.
2864  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
2865  baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
2866 
2867  // Verify all properties of a shaped type: rank, element type and dimension
2868  // sizes. This takes into account potential rank reductions.
2869  auto shapedTypeVerification = isRankReducedType(
2870  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2871  if (shapedTypeVerification != SliceVerificationResult::Success)
2872  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2873 
2874  // Make sure that the memory space did not change.
2875  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
2876  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
2877  *this, expectedType);
2878 
2879  // Verify the offset of the layout map.
2880  if (!haveCompatibleOffsets(expectedType, subViewType))
2881  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2882  *this, expectedType);
2883 
2884  // The only thing that's left to verify now are the strides. First, compute
2885  // the unused dimensions due to rank reductions. We have to look at sizes and
2886  // strides to decide which dimensions were dropped. This function also
2887  // partially verifies strides in case of rank reductions.
2888  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
2889  getMixedSizes());
2890  if (failed(unusedDims))
2891  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2892  *this, expectedType);
2893 
2894  // Strides must match.
2895  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
2896  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2897  *this, expectedType);
2898 
2899  return success();
2900 }
2901 
2902 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2903  return os << "range " << range.offset << ":" << range.size << ":"
2904  << range.stride;
2905 }
2906 
2907 /// Return the list of Range (i.e. offset, size, stride). Each Range
2908 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2909 /// with `b` at location `loc`.
2910 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2911  OpBuilder &b, Location loc) {
2912  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
2913  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
2914  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
2915  SmallVector<Range, 8> res;
2916  unsigned rank = ranks[0];
2917  res.reserve(rank);
2918  for (unsigned idx = 0; idx < rank; ++idx) {
2919  Value offset =
2920  op.isDynamicOffset(idx)
2921  ? op.getDynamicOffset(idx)
2922  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
2923  Value size =
2924  op.isDynamicSize(idx)
2925  ? op.getDynamicSize(idx)
2926  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
2927  Value stride =
2928  op.isDynamicStride(idx)
2929  ? op.getDynamicStride(idx)
2930  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
2931  res.emplace_back(Range{offset, size, stride});
2932  }
2933  return res;
2934 }
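// Minimal usage sketch (editorial addition; `subView` is an assumed existing
// SubViewOp, which implements OffsetSizeAndStrideOpInterface):
//
//   OpBuilder b(subView);
//   SmallVector<Range, 8> ranges = getOrCreateRanges(
//       cast<OffsetSizeAndStrideOpInterface>(subView.getOperation()), b,
//       subView.getLoc());
//
// Each Range then carries either the dynamic Value or a newly created
// arith.constant of index type for the offset, size, and stride at that index.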
2935 
2936 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
2937 /// to deduce the result type for the given `sourceType`. Additionally, reduce
2938 /// the rank of the inferred result type if `currentResultType` has a lower rank
2939 /// than `currentSourceType`. Use this signature if `sourceType` is updated
2940 /// together with the result type. In this case, it is important to compute
2941 /// the dropped dimensions using `currentSourceType` whose strides align with
2942 /// `currentResultType`.
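///
/// Illustrative example (editorial sketch, not in the upstream source): if the
/// type inferred without rank reduction is memref<1x4xf32, strided<[4, 1]>> and
/// `currentResultType` has rank 1, the unit dimension is dropped and the
/// canonical result type becomes memref<4xf32, strided<[1]>>.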
2943 static MemRefType getCanonicalSubViewResultType(
2944  MemRefType currentResultType, MemRefType currentSourceType,
2945  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
2946  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
2947  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2948  sourceType, mixedOffsets, mixedSizes, mixedStrides));
2949  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
2950  currentSourceType, currentResultType, mixedSizes);
2951  if (failed(unusedDims))
2952  return nullptr;
2953 
2954  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
2955  SmallVector<int64_t> shape, strides;
2956  unsigned numDimsAfterReduction =
2957  nonRankReducedType.getRank() - unusedDims->count();
2958  shape.reserve(numDimsAfterReduction);
2959  strides.reserve(numDimsAfterReduction);
2960  for (const auto &[idx, size, stride] :
2961  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
2962  nonRankReducedType.getShape(), layout.getStrides())) {
2963  if (unusedDims->test(idx))
2964  continue;
2965  shape.push_back(size);
2966  strides.push_back(stride);
2967  }
2968 
2969  return MemRefType::get(shape, nonRankReducedType.getElementType(),
2970  StridedLayoutAttr::get(sourceType.getContext(),
2971  layout.getOffset(), strides),
2972  nonRankReducedType.getMemorySpace());
2973 }
2974 
2975 Value mlir::memref::createCanonicalRankReducingSubViewOp(
2976  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
2977  auto memrefType = llvm::cast<MemRefType>(memref.getType());
2978  unsigned rank = memrefType.getRank();
2979  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2980  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
2981  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2982  auto targetType =
2983  llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
2984  targetShape, memrefType, offsets, sizes, strides));
2985  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
2986  sizes, strides);
2987 }
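// Illustrative example (editorial sketch, not in the upstream source): for a
// source of type memref<1x16x1xf32> and targetShape = {16}, this builds roughly
//
//   %sub = memref.subview %src[0, 0, 0] [1, 16, 1] [1, 1, 1]
//       : memref<1x16x1xf32> to memref<16xf32, strided<[1]>>
//
// i.e. a rank-reducing subview with zero offsets, unit strides, and the full
// source sizes.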
2988 
2989 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
2990  Value value,
2991  ArrayRef<int64_t> desiredShape) {
2992  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
2993  assert(sourceMemrefType && "not a ranked memref type");
2994  auto sourceShape = sourceMemrefType.getShape();
2995  if (sourceShape.equals(desiredShape))
2996  return value;
2997  auto maybeRankReductionMask =
2998  mlir::computeRankReductionMask(sourceShape, desiredShape);
2999  if (!maybeRankReductionMask)
3000  return failure();
3001  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3002 }
3003 
3004 /// Helper method to check if a `subview` operation is trivially a no-op. This
3005 /// is the case if all offsets are zero, all strides are 1, and the source
3006 /// shape is the same as the size of the subview. In such cases, the subview can
3007 /// be folded into its source.
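///
/// Illustrative example (editorial sketch, not in the upstream source): the
/// following subview is trivial and is folded away by the pattern below:
///
/// ```
/// %0 = memref.subview %src[0, 0] [4, 8] [1, 1]
///     : memref<4x8xf32> to memref<4x8xf32>
/// ```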
3008 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3009  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3010  return false;
3011 
3012  auto mixedOffsets = subViewOp.getMixedOffsets();
3013  auto mixedSizes = subViewOp.getMixedSizes();
3014  auto mixedStrides = subViewOp.getMixedStrides();
3015 
3016  // Check offsets are zero.
3017  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3018  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3019  return !intValue || intValue.value() != 0;
3020  }))
3021  return false;
3022 
3023  // Check strides are one.
3024  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3025  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3026  return !intValue || intValue.value() != 1;
3027  }))
3028  return false;
3029 
3030  // Check that all size values are static and match the (static) source shape.
3031  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3032  for (const auto &size : llvm::enumerate(mixedSizes)) {
3033  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3034  if (!intValue || *intValue != sourceShape[size.index()])
3035  return false;
3036  }
3037  // All conditions met. The `SubViewOp` is foldable as a no-op.
3038  return true;
3039 }
3040 
3041 namespace {
3042 /// Pattern to rewrite a subview op with MemRefCast arguments.
3043 /// This essentially pushes memref.cast past its consuming subview when
3044 /// `canFoldIntoConsumerOp` is true.
3045 ///
3046 /// Example:
3047 /// ```
3048 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3049 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3050 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3051 /// ```
3052 /// is rewritten into:
3053 /// ```
3054 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3055 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3056 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3057 /// ```
3058 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3059 public:
3060  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3061 
3062  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3063  PatternRewriter &rewriter) const override {
3064  // Any constant operand, just return to let SubViewOpConstantFolder kick
3065  // in.
3066  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3067  return matchPattern(operand, matchConstantIndex());
3068  }))
3069  return failure();
3070 
3071  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3072  if (!castOp)
3073  return failure();
3074 
3075  if (!CastOp::canFoldIntoConsumerOp(castOp))
3076  return failure();
3077 
3078  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3079  // the MemRefCastOp source operand type to infer the result type and the
3080  // current SubViewOp source operand type to compute the dropped dimensions
3081  // if the operation is rank-reducing.
3082  auto resultType = getCanonicalSubViewResultType(
3083  subViewOp.getType(), subViewOp.getSourceType(),
3084  llvm::cast<MemRefType>(castOp.getSource().getType()),
3085  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3086  subViewOp.getMixedStrides());
3087  if (!resultType)
3088  return failure();
3089 
3090  Value newSubView = rewriter.create<SubViewOp>(
3091  subViewOp.getLoc(), resultType, castOp.getSource(),
3092  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3093  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3094  subViewOp.getStaticStrides());
3095  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3096  newSubView);
3097  return success();
3098  }
3099 };
3100 
3101 /// Canonicalize subview ops that are no-ops: replace them with their source, or
3102 /// with a cast when the types differ only in layout (e.g. due to use of `affine_map`).
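///
/// Example sketch (editorial, not in the upstream source):
///
/// ```
/// %0 = memref.subview %arg[0, 0] [4, 8] [1, 1]
///     : memref<4x8xf32> to memref<4x8xf32, strided<[8, 1]>>
/// ```
/// is replaced by a memref.cast from memref<4x8xf32> to
/// memref<4x8xf32, strided<[8, 1]>>.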
3103 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3104 public:
3105  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3106 
3107  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3108  PatternRewriter &rewriter) const override {
3109  if (!isTrivialSubViewOp(subViewOp))
3110  return failure();
3111  if (subViewOp.getSourceType() == subViewOp.getType()) {
3112  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3113  return success();
3114  }
3115  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3116  subViewOp.getSource());
3117  return success();
3118  }
3119 };
3120 } // namespace
3121 
3122 /// Return the canonical type of the result of a subview.
3123 struct SubViewReturnTypeCanonicalizer {
3124  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3125  ArrayRef<OpFoldResult> mixedSizes,
3126  ArrayRef<OpFoldResult> mixedStrides) {
3127  // Infer a memref type without taking into account any rank reductions.
3128  auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
3129  mixedSizes, mixedStrides);
3130  if (!resTy)
3131  return {};
3132  MemRefType nonReducedType = cast<MemRefType>(resTy);
3133 
3134  // Directly return the non-rank reduced type if there are no dropped dims.
3135  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3136  if (droppedDims.none())
3137  return nonReducedType;
3138 
3139  // Take the strides and offset from the non-rank reduced type.
3140  auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
3141 
3142  // Drop dims from shape and strides.
3143  SmallVector<int64_t> targetShape;
3144  SmallVector<int64_t> targetStrides;
3145  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3146  if (droppedDims.test(i))
3147  continue;
3148  targetStrides.push_back(nonReducedStrides[i]);
3149  targetShape.push_back(nonReducedType.getDimSize(i));
3150  }
3151 
3152  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3153  StridedLayoutAttr::get(nonReducedType.getContext(),
3154  offset, targetStrides),
3155  nonReducedType.getMemorySpace());
3156  }
3157 };
3158 
3159 /// A canonicalizer wrapper to replace SubViewOps.
3160 struct SubViewCanonicalizer {
3161  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3162  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3163  }
3164 };
3165 
3166 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3167  MLIRContext *context) {
3168  results
3169  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3170  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3171  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3172 }
3173 
3174 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3175  auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
3176  auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());
3177 
3178  if (resultShapedType.hasStaticShape() &&
3179  resultShapedType == sourceShapedType) {
3180  return getViewSource();
3181  }
3182 
3183  // Fold subview(subview(x)), where both subviews have the same size and the
3184  // second subview's offsets are all zero. (I.e., the second subview is a
3185  // no-op.)
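 // Illustrative sketch (editorial, not in the upstream source): %1 below folds
 // to %0 because its offsets are 0, its strides are 1, its sizes equal those
 // of %0, and its source and result types match:
 //
 //   %0 = memref.subview %src[2, 2] [4, 4] [1, 1]
 //       : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: 18>>
 //   %1 = memref.subview %0[0, 0] [4, 4] [1, 1]
 //       : memref<4x4xf32, strided<[8, 1], offset: 18>>
 //         to memref<4x4xf32, strided<[8, 1], offset: 18>>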
3186  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3187  auto srcSizes = srcSubview.getMixedSizes();
3188  auto sizes = getMixedSizes();
3189  auto offsets = getMixedOffsets();
3190  bool allOffsetsZero = llvm::all_of(
3191  offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3192  auto strides = getMixedStrides();
3193  bool allStridesOne = llvm::all_of(
3194  strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3195  bool allSizesSame = llvm::equal(sizes, srcSizes);
3196  if (allOffsetsZero && allStridesOne && allSizesSame &&
3197  resultShapedType == sourceShapedType)
3198  return getViewSource();
3199  }
3200 
3201  return {};
3202 }
3203 
3204 //===----------------------------------------------------------------------===//
3205 // TransposeOp
3206 //===----------------------------------------------------------------------===//
3207 
3208 void TransposeOp::getAsmResultNames(
3209  function_ref<void(Value, StringRef)> setNameFn) {
3210  setNameFn(getResult(), "transpose");
3211 }
3212 
3213 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3214 static MemRefType inferTransposeResultType(MemRefType memRefType,
3215  AffineMap permutationMap) {
3216  auto originalSizes = memRefType.getShape();
3217  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
3218  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3219 
3220  // Compute permuted sizes and strides.
3221  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3222  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3223 
3224  return MemRefType::Builder(memRefType)
3225  .setShape(sizes)
3226  .setLayout(
3227  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3228 }
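// Illustrative example (editorial sketch, not in the upstream source): applying
// the permutation (d0, d1) -> (d1, d0) to memref<3x5xf32>, whose canonical
// strides are [5, 1] with offset 0, yields memref<5x3xf32, strided<[1, 5]>>.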
3229 
3230 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3231  AffineMapAttr permutation,
3232  ArrayRef<NamedAttribute> attrs) {
3233  auto permutationMap = permutation.getValue();
3234  assert(permutationMap);
3235 
3236  auto memRefType = llvm::cast<MemRefType>(in.getType());
3237  // Compute result type.
3238  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3239 
3240  build(b, result, resultType, in, attrs);
3241  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3242 }
3243 
3244 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3245 void TransposeOp::print(OpAsmPrinter &p) {
3246  p << " " << getIn() << " " << getPermutation();
3247  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3248  p << " : " << getIn().getType() << " to " << getType();
3249 }
3250 
3251 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3252  OpAsmParser::UnresolvedOperand in;
3253  AffineMap permutation;
3254  MemRefType srcType, dstType;
3255  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3256  parser.parseOptionalAttrDict(result.attributes) ||
3257  parser.parseColonType(srcType) ||
3258  parser.resolveOperand(in, srcType, result.operands) ||
3259  parser.parseKeywordType("to", dstType) ||
3260  parser.addTypeToList(dstType, result.types))
3261  return failure();
3262 
3263  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3264  AffineMapAttr::get(permutation));
3265  return success();
3266 }
3267 
3268 LogicalResult TransposeOp::verify() {
3269  if (!getPermutation().isPermutation())
3270  return emitOpError("expected a permutation map");
3271  if (getPermutation().getNumDims() != getIn().getType().getRank())
3272  return emitOpError("expected a permutation map of same rank as the input");
3273 
3274  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3275  auto resultType = llvm::cast<MemRefType>(getType());
3276  auto canonicalResultType = canonicalizeStridedLayout(
3277  inferTransposeResultType(srcType, getPermutation()));
3278 
3279  if (canonicalizeStridedLayout(resultType) != canonicalResultType)
3280  return emitOpError("result type ")
3281  << resultType
3282  << " is not equivalent to the canonical transposed input type "
3283  << canonicalResultType;
3284  return success();
3285 }
3286 
3287 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3288  // First check for identity permutation, we can fold it away if input and
3289  // result types are identical already.
3290  if (getPermutation().isIdentity() && getType() == getIn().getType())
3291  return getIn();
3292  // Fold two consecutive memref.transpose Ops into one by composing their
3293  // permutation maps.
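 // Illustrative sketch (editorial, not in the upstream source):
 //
 //   %t0 = memref.transpose %src (d0, d1) -> (d1, d0)
 //       : memref<3x5xf32> to memref<5x3xf32, strided<[1, 5]>>
 //   %t1 = memref.transpose %t0 (d0, d1) -> (d1, d0)
 //       : memref<5x3xf32, strided<[1, 5]>> to memref<3x5xf32, strided<[5, 1]>>
 //
 // folds %t1 so that it reads %src directly with the composed (identity)
 // permutation.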
3294  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3295  AffineMap composedPermutation =
3296  getPermutation().compose(otherTransposeOp.getPermutation());
3297  getInMutable().assign(otherTransposeOp.getIn());
3298  setPermutation(composedPermutation);
3299  return getResult();
3300  }
3301  return {};
3302 }
3303 
3304 //===----------------------------------------------------------------------===//
3305 // ViewOp
3306 //===----------------------------------------------------------------------===//
3307 
3308 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3309  setNameFn(getResult(), "view");
3310 }
3311 
3312 LogicalResult ViewOp::verify() {
3313  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3314  auto viewType = getType();
3315 
3316  // The base memref should have identity layout map (or none).
3317  if (!baseType.getLayout().isIdentity())
3318  return emitError("unsupported map for base memref type ") << baseType;
3319 
3320  // The result memref should have identity layout map (or none).
3321  if (!viewType.getLayout().isIdentity())
3322  return emitError("unsupported map for result memref type ") << viewType;
3323 
3324  // The base memref and the view memref should be in the same memory space.
3325  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3326  return emitError("different memory spaces specified for base memref "
3327  "type ")
3328  << baseType << " and view memref type " << viewType;
3329 
3330  // Verify that we have the correct number of sizes for the result type.
3331  unsigned numDynamicDims = viewType.getNumDynamicDims();
3332  if (getSizes().size() != numDynamicDims)
3333  return emitError("incorrect number of size operands for type ") << viewType;
3334 
3335  return success();
3336 }
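// Illustrative example (editorial sketch, not in the upstream source): a valid
// view with one dynamic result dimension takes exactly one size operand:
//
//   %v = memref.view %buffer[%offset][%size]
//       : memref<2048xi8> to memref<?x4xf32>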
3337 
3338 Value ViewOp::getViewSource() { return getSource(); }
3339 
3340 namespace {
3341 
3342 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3343  using OpRewritePattern<ViewOp>::OpRewritePattern;
3344 
3345  LogicalResult matchAndRewrite(ViewOp viewOp,
3346  PatternRewriter &rewriter) const override {
3347  // Return if none of the operands are constants.
3348  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3349  return matchPattern(operand, matchConstantIndex());
3350  }))
3351  return failure();
3352 
3353  // Get result memref type.
3354  auto memrefType = viewOp.getType();
3355 
3356  // Get offset from old memref view type 'memRefType'.
3357  int64_t oldOffset;
3358  SmallVector<int64_t, 4> oldStrides;
3359  if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
3360  return failure();
3361  assert(oldOffset == 0 && "Expected 0 offset");
3362 
3363  SmallVector<Value, 4> newOperands;
3364 
3365  // Offset cannot be folded into result type.
3366 
3367  // Fold any dynamic dim operands which are produced by a constant.
3368  SmallVector<int64_t, 4> newShapeConstants;
3369  newShapeConstants.reserve(memrefType.getRank());
3370 
3371  unsigned dynamicDimPos = 0;
3372  unsigned rank = memrefType.getRank();
3373  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3374  int64_t dimSize = memrefType.getDimSize(dim);
3375  // If this is already a static dimension, keep it.
3376  if (!ShapedType::isDynamic(dimSize)) {
3377  newShapeConstants.push_back(dimSize);
3378  continue;
3379  }
3380  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3381  if (auto constantIndexOp =
3382  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3383  // Dynamic shape dimension will be folded.
3384  newShapeConstants.push_back(constantIndexOp.value());
3385  } else {
3386  // Dynamic shape dimension not folded; copy operand from old memref.
3387  newShapeConstants.push_back(dimSize);
3388  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3389  }
3390  dynamicDimPos++;
3391  }
3392 
3393  // Create new memref type with constant folded dims.
3394  MemRefType newMemRefType =
3395  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3396  // Nothing new, don't fold.
3397  if (newMemRefType == memrefType)
3398  return failure();
3399 
3400  // Create new ViewOp.
3401  auto newViewOp = rewriter.create<ViewOp>(
3402  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3403  viewOp.getByteShift(), newOperands);
3404  // Insert a cast so we have the same type as the old memref type.
3405  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3406  return success();
3407  }
3408 };
3409 
3410 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3411  using OpRewritePattern<ViewOp>::OpRewritePattern;
3412 
3413  LogicalResult matchAndRewrite(ViewOp viewOp,
3414  PatternRewriter &rewriter) const override {
3415  Value memrefOperand = viewOp.getOperand(0);
3416  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3417  if (!memrefCastOp)
3418  return failure();
3419  Value allocOperand = memrefCastOp.getOperand();
3420  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3421  if (!allocOp)
3422  return failure();
3423  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3424  viewOp.getByteShift(),
3425  viewOp.getSizes());
3426  return success();
3427  }
3428 };
3429 
3430 } // namespace
3431 
3432 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3433  MLIRContext *context) {
3434  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3435 }
3436 
3437 //===----------------------------------------------------------------------===//
3438 // AtomicRMWOp
3439 //===----------------------------------------------------------------------===//
3440 
3441 LogicalResult AtomicRMWOp::verify() {
3442  if (getMemRefType().getRank() != getNumOperands() - 2)
3443  return emitOpError(
3444  "expects the number of subscripts to be equal to memref rank");
3445  switch (getKind()) {
3446  case arith::AtomicRMWKind::addf:
3447  case arith::AtomicRMWKind::maximumf:
3448  case arith::AtomicRMWKind::minimumf:
3449  case arith::AtomicRMWKind::mulf:
3450  if (!llvm::isa<FloatType>(getValue().getType()))
3451  return emitOpError() << "with kind '"
3452  << arith::stringifyAtomicRMWKind(getKind())
3453  << "' expects a floating-point type";
3454  break;
3455  case arith::AtomicRMWKind::addi:
3456  case arith::AtomicRMWKind::maxs:
3457  case arith::AtomicRMWKind::maxu:
3458  case arith::AtomicRMWKind::mins:
3459  case arith::AtomicRMWKind::minu:
3460  case arith::AtomicRMWKind::muli:
3461  case arith::AtomicRMWKind::ori:
3462  case arith::AtomicRMWKind::andi:
3463  if (!llvm::isa<IntegerType>(getValue().getType()))
3464  return emitOpError() << "with kind '"
3465  << arith::stringifyAtomicRMWKind(getKind())
3466  << "' expects an integer type";
3467  break;
3468  default:
3469  break;
3470  }
3471  return success();
3472 }
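// Illustrative example (editorial sketch, not in the upstream source; the exact
// textual form follows the op's assembly format): a well-formed op supplies one
// subscript per memref dimension and pairs a floating-point kind with a
// floating-point value, e.g. roughly
//
//   %old = memref.atomic_rmw addf %value, %mem[%i, %j]
//       : (f32, memref<10x10xf32>) -> f32
//
// whereas kind addf applied to an i32 value is rejected with
// "expects a floating-point type".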
3473 
3474 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3475  /// atomicrmw(memrefcast) -> atomicrmw
3476  if (succeeded(foldMemRefCast(*this, getValue())))
3477  return getResult();
3478  return OpFoldResult();
3479 }
3480 
3481 //===----------------------------------------------------------------------===//
3482 // TableGen'd op method definitions
3483 //===----------------------------------------------------------------------===//
3484 
3485 #define GET_OP_CLASSES
3486 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"