MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 /// Materialize a single constant operation from a given attribute value with
30 /// the desired resultant type.
31 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
32                                               Attribute value, Type type,
33  Location loc) {
34  return arith::ConstantOp::materialize(builder, value, type, loc);
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // Common canonicalization pattern support logic
39 //===----------------------------------------------------------------------===//
40 
41 /// This is a common class used for patterns of the form
42 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
43 /// into the root operation directly.
44 static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
45   bool folded = false;
46  for (OpOperand &operand : op->getOpOperands()) {
47  auto cast = operand.get().getDefiningOp<CastOp>();
48  if (cast && operand.get() != inner &&
49  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
50  operand.set(cast.getOperand());
51  folded = true;
52  }
53  }
54  return success(folded);
55 }
56 
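// Illustrative sketch of the fold implemented above (hand-written example,
// SSA names invented for illustration):
// ```mlir
// %0 = memref.cast %m : memref<8xf32> to memref<?xf32>
// memref.dealloc %0 : memref<?xf32>
// // The cast is folded into the consuming op:
// memref.dealloc %m : memref<8xf32>
// ```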
57 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
58 /// type.
59 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
60   if (auto memref = llvm::dyn_cast<MemRefType>(type))
61  return RankedTensorType::get(memref.getShape(), memref.getElementType());
62  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
63  return UnrankedTensorType::get(memref.getElementType());
64  return NoneType::get(type.getContext());
65 }
66 
67 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
68                                   int64_t dim) {
69  auto memrefType = llvm::cast<MemRefType>(value.getType());
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78                                                 Location loc, Value value) {
79   auto memrefType = llvm::cast<MemRefType>(value.getType());
80   SmallVector<OpFoldResult> result;
81   for (int64_t i = 0; i < memrefType.getRank(); ++i)
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result;
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that infers the constant values from a list of \p values,
91 /// a \p memRefTy, and another helper function \p getAttributes.
92 /// The inferred constant values replace the related `OpFoldResult` in
93 /// \p values.
94 ///
95 /// \note This function shouldn't be used directly, instead, use the
96 /// `getConstifiedMixedXXX` methods from the related operations.
97 ///
98 /// \p getAttributes returns a list of potentially constant values, as determined
99 /// by \p isDynamic, from the given \p memRefTy. The returned list must have as
100 /// many elements as \p values or be empty.
101 ///
102 /// E.g., consider the following example:
103 /// ```
104 /// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
105 /// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
106 /// ```
107 /// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
108 /// Now using this helper function with:
109 /// - `values == [2, %dyn_stride]`,
110 /// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
111 /// - `getAttributes == getConstantStrides` (i.e., a wrapper around
112 /// `getStridesAndOffset`), and
113 /// - `isDynamic == ShapedType::isDynamic`
114 /// Will yield: `values == [2, 1]`
115 static void constifyIndexValues(
116     SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
117  MLIRContext *ctxt,
118  llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
119  llvm::function_ref<bool(int64_t)> isDynamic) {
120  SmallVector<int64_t> constValues = getAttributes(memRefTy);
121  Builder builder(ctxt);
122  for (const auto &it : llvm::enumerate(constValues)) {
123  int64_t constValue = it.value();
124  if (!isDynamic(constValue))
125  values[it.index()] = builder.getIndexAttr(constValue);
126  }
127  for (OpFoldResult &ofr : values) {
128  if (ofr.is<Attribute>()) {
129  // FIXME: We shouldn't need to do that, but right now, the static indices
130  // are created with the wrong type: `i64` instead of `index`.
131  // As a result, if we were to keep the attribute as is, we may fail to see
132  // that two attributes are equal because one would have the i64 type and
133  // the other the index type.
134  // The alternative would be to create constant indices with getI64Attr in
135  // this and the previous loop, but it doesn't logically make sense (we are
136 /// dealing with indices here) and would only strengthen the inconsistency
137  // around how static indices are created (some places use getI64Attr,
138  // others use getIndexAttr).
139  // The workaround here is to stick to the IndexAttr type for all the
140  // values, hence we recreate the attribute even when it is already static
141  // to make sure the type is consistent.
142  ofr = builder.getIndexAttr(
143  llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
144  continue;
145  }
146  std::optional<int64_t> maybeConstant =
147  getConstantIntValue(ofr.get<Value>());
148  if (maybeConstant)
149  ofr = builder.getIndexAttr(*maybeConstant);
150  }
151 }
152 
153 /// Wrapper around `getShape` that conforms to the function signature
154 /// expected for `getAttributes` in `constifyIndexValues`.
155 static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
156  ArrayRef<int64_t> sizes = memRefTy.getShape();
157  return SmallVector<int64_t>(sizes.begin(), sizes.end());
158 }
159 
160 /// Wrapper around `getStridesAndOffset` that returns only the offset and
161 /// conforms to the function signature expected for `getAttributes` in
162 /// `constifyIndexValues`.
163 static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
164  SmallVector<int64_t> strides;
165  int64_t offset;
166  LogicalResult hasStaticInformation =
167  getStridesAndOffset(memrefType, strides, offset);
168  if (failed(hasStaticInformation))
169  return SmallVector<int64_t>();
170  return SmallVector<int64_t>(1, offset);
171 }
172 
173 /// Wrapper around `getStridesAndOffset` that returns only the strides and
174 /// conforms to the function signature expected for `getAttributes` in
175 /// `constifyIndexValues`.
176 static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
177  SmallVector<int64_t> strides;
178  int64_t offset;
179  LogicalResult hasStaticInformation =
180  getStridesAndOffset(memrefType, strides, offset);
181  if (failed(hasStaticInformation))
182  return SmallVector<int64_t>();
183  return strides;
184 }
185 
186 //===----------------------------------------------------------------------===//
187 // AllocOp / AllocaOp
188 //===----------------------------------------------------------------------===//
189 
190 void AllocOp::getAsmResultNames(
191  function_ref<void(Value, StringRef)> setNameFn) {
192  setNameFn(getResult(), "alloc");
193 }
194 
195 void AllocaOp::getAsmResultNames(
196  function_ref<void(Value, StringRef)> setNameFn) {
197  setNameFn(getResult(), "alloca");
198 }
199 
200 template <typename AllocLikeOp>
201 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
202  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
203  "applies to only alloc or alloca");
204  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
205  if (!memRefType)
206  return op.emitOpError("result must be a memref");
207 
208  if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
209  memRefType.getNumDynamicDims())
210  return op.emitOpError("dimension operand count does not equal memref "
211  "dynamic dimension count");
212 
213  unsigned numSymbols = 0;
214  if (!memRefType.getLayout().isIdentity())
215  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
216  if (op.getSymbolOperands().size() != numSymbols)
217  return op.emitOpError("symbol operand count does not equal memref symbol "
218  "count: expected ")
219  << numSymbols << ", got " << op.getSymbolOperands().size();
220 
221  return success();
222 }
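// For illustration, alloc/alloca forms that satisfy the checks above
// (hand-written sketch; %n is an invented index value):
// ```mlir
// // One size operand per dynamic dimension, no symbol operands for an
// // identity layout.
// %0 = memref.alloc(%n) : memref<8x?xf32>
// %1 = memref.alloca() : memref<4x4xf32>
// ```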
223 
224 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
225 
226 LogicalResult AllocaOp::verify() {
227   // An alloca op needs to have an ancestor with an allocation scope trait.
228  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
229  return emitOpError(
230  "requires an ancestor op with AutomaticAllocationScope trait");
231 
232  return verifyAllocLikeOp(*this);
233 }
234 
235 namespace {
236 /// Fold constant dimensions into an alloc like operation.
237 template <typename AllocLikeOp>
238 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
239   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
240 
241  LogicalResult matchAndRewrite(AllocLikeOp alloc,
242  PatternRewriter &rewriter) const override {
243   // Check to see if any dimension operands are constants. If so, we can
244  // substitute and drop them.
245  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
246  APInt constSizeArg;
247  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
248  return false;
249  return constSizeArg.isNonNegative();
250  }))
251  return failure();
252 
253  auto memrefType = alloc.getType();
254 
255  // Ok, we have one or more constant operands. Collect the non-constant ones
256  // and keep track of the resultant memref type to build.
257  SmallVector<int64_t, 4> newShapeConstants;
258  newShapeConstants.reserve(memrefType.getRank());
259  SmallVector<Value, 4> dynamicSizes;
260 
261  unsigned dynamicDimPos = 0;
262  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
263  int64_t dimSize = memrefType.getDimSize(dim);
264   // If this is already a static dimension, keep it.
265  if (!ShapedType::isDynamic(dimSize)) {
266  newShapeConstants.push_back(dimSize);
267  continue;
268  }
269  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
270  APInt constSizeArg;
271  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
272  constSizeArg.isNonNegative()) {
273  // Dynamic shape dimension will be folded.
274  newShapeConstants.push_back(constSizeArg.getZExtValue());
275  } else {
276  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
277  newShapeConstants.push_back(ShapedType::kDynamic);
278  dynamicSizes.push_back(dynamicSize);
279  }
280  dynamicDimPos++;
281  }
282 
283  // Create new memref type (which will have fewer dynamic dimensions).
284  MemRefType newMemRefType =
285  MemRefType::Builder(memrefType).setShape(newShapeConstants);
286  assert(static_cast<int64_t>(dynamicSizes.size()) ==
287  newMemRefType.getNumDynamicDims());
288 
289  // Create and insert the alloc op for the new memref.
290  auto newAlloc = rewriter.create<AllocLikeOp>(
291  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
292  alloc.getAlignmentAttr());
293  // Insert a cast so we have the same type as the old alloc.
294  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
295  return success();
296  }
297 };
298 
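// Illustrative before/after for SimplifyAllocConst (hand-written sketch):
// ```mlir
// // Before: one of the dynamic sizes is a constant.
// %c8 = arith.constant 8 : index
// %0  = memref.alloc(%c8, %n) : memref<?x?xf32>
// // After: the constant becomes a static dimension; a cast preserves the
// // original type for existing users.
// %1  = memref.alloc(%n) : memref<8x?xf32>
// %2  = memref.cast %1 : memref<8x?xf32> to memref<?x?xf32>
// ```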
299 /// Fold alloc operations with no users or only store and dealloc uses.
300 template <typename T>
301 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
302   using OpRewritePattern<T>::OpRewritePattern;
303 
304  LogicalResult matchAndRewrite(T alloc,
305  PatternRewriter &rewriter) const override {
306  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
307  if (auto storeOp = dyn_cast<StoreOp>(op))
308  return storeOp.getValue() == alloc;
309  return !isa<DeallocOp>(op);
310  }))
311  return failure();
312 
313  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
314  rewriter.eraseOp(user);
315 
316  rewriter.eraseOp(alloc);
317  return success();
318  }
319 };
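// Illustrative sketch of SimplifyDeadAlloc (hand-written example): an
// allocation whose only uses are stores into it and deallocations is dead,
// so all three ops below are erased.
// ```mlir
// %0 = memref.alloc() : memref<16xf32>
// memref.store %v, %0[%i] : memref<16xf32>
// memref.dealloc %0 : memref<16xf32>
// ```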
320 } // namespace
321 
322 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
323  MLIRContext *context) {
324  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
325 }
326 
327 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
328  MLIRContext *context) {
329  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
330  context);
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // ReallocOp
335 //===----------------------------------------------------------------------===//
336 
337 LogicalResult ReallocOp::verify() {
338   auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
339  MemRefType resultType = getType();
340 
341  // The source memref should have identity layout (or none).
342  if (!sourceType.getLayout().isIdentity())
343  return emitError("unsupported layout for source memref type ")
344  << sourceType;
345 
346  // The result memref should have identity layout (or none).
347  if (!resultType.getLayout().isIdentity())
348  return emitError("unsupported layout for result memref type ")
349  << resultType;
350 
351  // The source memref and the result memref should be in the same memory space.
352  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
353  return emitError("different memory spaces specified for source memref "
354  "type ")
355  << sourceType << " and result memref type " << resultType;
356 
357  // The source memref and the result memref should have the same element type.
358  if (sourceType.getElementType() != resultType.getElementType())
359  return emitError("different element types specified for source memref "
360  "type ")
361  << sourceType << " and result memref type " << resultType;
362 
363  // Verify that we have the dynamic dimension operand when it is needed.
364  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
365  return emitError("missing dimension operand for result type ")
366  << resultType;
367  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
368  return emitError("unnecessary dimension operand for result type ")
369  << resultType;
370 
371  return success();
372 }
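// For illustration, realloc forms accepted by the verifier above
// (hand-written sketch; %n is an invented index value):
// ```mlir
// // Static result type: no dynamic size operand.
// %0 = memref.realloc %src : memref<64xf32> to memref<128xf32>
// // Dynamic result type: the new size must be passed as an operand.
// %1 = memref.realloc %src(%n) : memref<64xf32> to memref<?xf32>
// ```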
373 
374 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
375  MLIRContext *context) {
376  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
377 }
378 
379 //===----------------------------------------------------------------------===//
380 // AllocaScopeOp
381 //===----------------------------------------------------------------------===//
382 
383 void AllocaScopeOp::print(OpAsmPrinter &p) {
384   bool printBlockTerminators = false;
385 
386  p << ' ';
387  if (!getResults().empty()) {
388  p << " -> (" << getResultTypes() << ")";
389  printBlockTerminators = true;
390  }
391  p << ' ';
392  p.printRegion(getBodyRegion(),
393  /*printEntryBlockArgs=*/false,
394  /*printBlockTerminators=*/printBlockTerminators);
395  p.printOptionalAttrDict((*this)->getAttrs());
396 }
397 
398 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
399   // Create a region for the body.
400  result.regions.reserve(1);
401  Region *bodyRegion = result.addRegion();
402 
403  // Parse optional results type list.
404  if (parser.parseOptionalArrowTypeList(result.types))
405  return failure();
406 
407  // Parse the body region.
408  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
409  return failure();
410  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
411  result.location);
412 
413  // Parse the optional attribute list.
414  if (parser.parseOptionalAttrDict(result.attributes))
415  return failure();
416 
417  return success();
418 }
419 
420 void AllocaScopeOp::getSuccessorRegions(
421     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
422   if (!point.isParent()) {
423  regions.push_back(RegionSuccessor(getResults()));
424  return;
425  }
426 
427  regions.push_back(RegionSuccessor(&getBodyRegion()));
428 }
429 
430 /// Given an operation, return whether this op is guaranteed to
431 /// allocate an AutomaticAllocationScopeResource
432 static bool isGuaranteedAutomaticAllocation(Operation *op) {
433   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
434  if (!interface)
435  return false;
436  for (auto res : op->getResults()) {
437  if (auto effect =
438  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
439  if (isa<SideEffects::AutomaticAllocationScopeResource>(
440  effect->getResource()))
441  return true;
442  }
443  }
444  return false;
445 }
446 
447 /// Given an operation, return whether this op itself could
448 /// allocate an AutomaticAllocationScopeResource. Note that
449 /// this will not check whether an operation contained within
450 /// the op can allocate.
451 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
452   // This op itself doesn't create a stack allocation,
453   // the inner allocation should be handled separately.
454   if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
455     return false;
456  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
457  if (!interface)
458  return true;
459  for (auto res : op->getResults()) {
460  if (auto effect =
461  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
462  if (isa<SideEffects::AutomaticAllocationScopeResource>(
463  effect->getResource()))
464  return true;
465  }
466  }
467  return false;
468 }
469 
470 /// Return whether this op is the last non terminating op
471 /// in a region. That is to say, it is in a one-block region
472 /// and is only followed by a terminator. This prevents
473 /// extending the lifetime of allocations.
474 static bool lastNonTerminatorInRegion(Operation *op) {
475   return op->getNextNode() == op->getBlock()->getTerminator() &&
476  op->getParentRegion()->getBlocks().size() == 1;
477 }
478 
479 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
480 /// or it contains no allocation.
481 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
482   using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
483 
484  LogicalResult matchAndRewrite(AllocaScopeOp op,
485  PatternRewriter &rewriter) const override {
486  bool hasPotentialAlloca =
487  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
488  if (alloc == op)
489  return WalkResult::advance();
490           if (isOpItselfPotentialAutomaticAllocation(alloc))
491             return WalkResult::interrupt();
492  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
493  return WalkResult::skip();
494  return WalkResult::advance();
495  }).wasInterrupted();
496 
497  // If this contains no potential allocation, it is always legal to
498  // inline. Otherwise, consider two conditions:
499  if (hasPotentialAlloca) {
500  // If the parent isn't an allocation scope, or we are not the last
501  // non-terminator op in the parent, we will extend the lifetime.
502       if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
503         return failure();
504  if (!lastNonTerminatorInRegion(op))
505  return failure();
506  }
507 
508  Block *block = &op.getRegion().front();
509  Operation *terminator = block->getTerminator();
510  ValueRange results = terminator->getOperands();
511  rewriter.inlineBlockBefore(block, op);
512  rewriter.replaceOp(op, results);
513  rewriter.eraseOp(terminator);
514  return success();
515  }
516 };
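// Illustrative sketch of AllocaScopeInliner (hand-written example): a scope
// that performs no stack allocation is always legal to inline.
// ```mlir
// %r = memref.alloca_scope -> (f32) {
//   %sum = arith.addf %a, %b : f32
//   memref.alloca_scope.return %sum : f32
// }
// // becomes:
// %sum = arith.addf %a, %b : f32
// ```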
517 
518 /// Move allocations into an allocation scope, if it is legal to
519 /// move them (e.g. their operands are available at the location
520 /// the op would be moved to).
521 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
522   using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
523 
524  LogicalResult matchAndRewrite(AllocaScopeOp op,
525  PatternRewriter &rewriter) const override {
526 
527     if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
528       return failure();
529 
530  Operation *lastParentWithoutScope = op->getParentOp();
531 
532  if (!lastParentWithoutScope ||
533  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
534  return failure();
535 
536     // Only apply if this is the last non-terminator
537     // op in the block (lest the lifetime be extended)
538     // of a one-block region.
539  if (!lastNonTerminatorInRegion(op) ||
540  !lastNonTerminatorInRegion(lastParentWithoutScope))
541  return failure();
542 
543     while (!lastParentWithoutScope->getParentOp()
544                 ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
545       lastParentWithoutScope = lastParentWithoutScope->getParentOp();
546  if (!lastParentWithoutScope ||
547  !lastNonTerminatorInRegion(lastParentWithoutScope))
548  return failure();
549  }
550     assert(lastParentWithoutScope->getParentOp()
551                ->hasTrait<OpTrait::AutomaticAllocationScope>());
552 
553  Region *containingRegion = nullptr;
554  for (auto &r : lastParentWithoutScope->getRegions()) {
555  if (r.isAncestor(op->getParentRegion())) {
556  assert(containingRegion == nullptr &&
557  "only one region can contain the op");
558  containingRegion = &r;
559  }
560  }
561  assert(containingRegion && "op must be contained in a region");
562 
563  SmallVector<Operation *> toHoist;
564  op->walk([&](Operation *alloc) {
565       if (!isGuaranteedAutomaticAllocation(alloc))
566         return WalkResult::skip();
567 
568  // If any operand is not defined before the location of
569  // lastParentWithoutScope (i.e. where we would hoist to), skip.
570  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
571  return containingRegion->isAncestor(v.getParentRegion());
572  }))
573  return WalkResult::skip();
574  toHoist.push_back(alloc);
575  return WalkResult::advance();
576  });
577 
578  if (toHoist.empty())
579  return failure();
580  rewriter.setInsertionPoint(lastParentWithoutScope);
581  for (auto *op : toHoist) {
582  auto *cloned = rewriter.clone(*op);
583  rewriter.replaceOp(op, cloned->getResults());
584  }
585  return success();
586  }
587 };
588 
589 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
590  MLIRContext *context) {
591  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
592 }
593 
594 //===----------------------------------------------------------------------===//
595 // AssumeAlignmentOp
596 //===----------------------------------------------------------------------===//
597 
598 LogicalResult AssumeAlignmentOp::verify() {
599   if (!llvm::isPowerOf2_32(getAlignment()))
600  return emitOpError("alignment must be power of 2");
601  return success();
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // CastOp
606 //===----------------------------------------------------------------------===//
607 
608 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
609  setNameFn(getResult(), "cast");
610 }
611 
612 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
613 /// source memref. This is useful to fold a memref.cast into a consuming op
614 /// and implement canonicalization patterns for ops in different dialects that
615 /// may consume the results of memref.cast operations. Such foldable memref.cast
616 /// operations are typically inserted as `view` and `subview` ops are
617 /// canonicalized, to preserve the type compatibility of their uses.
618 ///
619 /// Returns true when all conditions are met:
620 /// 1. source and result are ranked memrefs with strided semantics and same
621 /// element type and rank.
622 /// 2. each of the source's size, offset or stride has more static information
623 /// than the corresponding result's size, offset or stride.
624 ///
625 /// Example 1:
626 /// ```mlir
627 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
628 /// %2 = consumer %1 ... : memref<?x?xf32> ...
629 /// ```
630 ///
631 /// may fold into:
632 ///
633 /// ```mlir
634 /// %2 = consumer %0 ... : memref<8x16xf32> ...
635 /// ```
636 ///
637 /// Example 2:
638 /// ```
639 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
640 /// to memref<?x?xf32>
641 /// consumer %1 : memref<?x?xf32> ...
642 /// ```
643 ///
644 /// may fold into:
645 ///
646 /// ```
647 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
648 /// ```
649 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
650  MemRefType sourceType =
651  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
652  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
653 
654  // Requires ranked MemRefType.
655  if (!sourceType || !resultType)
656  return false;
657 
658  // Requires same elemental type.
659  if (sourceType.getElementType() != resultType.getElementType())
660  return false;
661 
662  // Requires same rank.
663  if (sourceType.getRank() != resultType.getRank())
664  return false;
665 
666  // Only fold casts between strided memref forms.
667  int64_t sourceOffset, resultOffset;
668  SmallVector<int64_t, 4> sourceStrides, resultStrides;
669  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
670  failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
671  return false;
672 
673  // If cast is towards more static sizes along any dimension, don't fold.
674  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
675  auto ss = std::get<0>(it), st = std::get<1>(it);
676  if (ss != st)
677  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
678  return false;
679  }
680 
681  // If cast is towards more static offset along any dimension, don't fold.
682  if (sourceOffset != resultOffset)
683  if (ShapedType::isDynamic(sourceOffset) &&
684  !ShapedType::isDynamic(resultOffset))
685  return false;
686 
687  // If cast is towards more static strides along any dimension, don't fold.
688  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
689  auto ss = std::get<0>(it), st = std::get<1>(it);
690  if (ss != st)
691  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
692  return false;
693  }
694 
695  return true;
696 }
697 
698 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
699  if (inputs.size() != 1 || outputs.size() != 1)
700  return false;
701  Type a = inputs.front(), b = outputs.front();
702  auto aT = llvm::dyn_cast<MemRefType>(a);
703  auto bT = llvm::dyn_cast<MemRefType>(b);
704 
705  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
706  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
707 
708  if (aT && bT) {
709  if (aT.getElementType() != bT.getElementType())
710  return false;
711  if (aT.getLayout() != bT.getLayout()) {
712  int64_t aOffset, bOffset;
713  SmallVector<int64_t, 4> aStrides, bStrides;
714  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
715  failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
716  aStrides.size() != bStrides.size())
717  return false;
718 
719  // Strides along a dimension/offset are compatible if the value in the
720  // source memref is static and the value in the target memref is the
721  // same. They are also compatible if either one is dynamic (see
722  // description of MemRefCastOp for details).
723  auto checkCompatible = [](int64_t a, int64_t b) {
724  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
725  };
726  if (!checkCompatible(aOffset, bOffset))
727  return false;
728  for (const auto &aStride : enumerate(aStrides))
729  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
730  return false;
731  }
732  if (aT.getMemorySpace() != bT.getMemorySpace())
733  return false;
734 
735  // They must have the same rank, and any specified dimensions must match.
736  if (aT.getRank() != bT.getRank())
737  return false;
738 
739  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
740  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
741  if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
742  aDim != bDim)
743  return false;
744  }
745  return true;
746  } else {
747  if (!aT && !uaT)
748  return false;
749  if (!bT && !ubT)
750  return false;
751  // Unranked to unranked casting is unsupported
752  if (uaT && ubT)
753  return false;
754 
755  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
756  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
757  if (aEltType != bEltType)
758  return false;
759 
760  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
761  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
762  return aMemSpace == bMemSpace;
763  }
764 
765  return false;
766 }
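// For illustration, casts that areCastCompatible accepts (hand-written
// sketch):
// ```mlir
// // Relaxing static sizes to dynamic ones is compatible.
// %0 = memref.cast %r : memref<4xf32> to memref<?xf32>
// // Ranked to unranked is compatible if the element type and memory space
// // match.
// %1 = memref.cast %r : memref<4xf32> to memref<*xf32>
// ```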
767 
768 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
769  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // CopyOp
774 //===----------------------------------------------------------------------===//
775 
776 namespace {
777 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
778 /// and element type, the cast can be skipped. Such CastOps only cast the layout
779 /// of the type.
780 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
781   using OpRewritePattern<CopyOp>::OpRewritePattern;
782 
783  LogicalResult matchAndRewrite(CopyOp copyOp,
784  PatternRewriter &rewriter) const override {
785  bool modified = false;
786 
787  // Check source.
788  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
789  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
790       auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
791 
792  if (fromType && toType) {
793  if (fromType.getShape() == toType.getShape() &&
794  fromType.getElementType() == toType.getElementType()) {
795  rewriter.modifyOpInPlace(copyOp, [&] {
796  copyOp.getSourceMutable().assign(castOp.getSource());
797  });
798  modified = true;
799  }
800  }
801  }
802 
803  // Check target.
804  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
805  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
806       auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
807 
808  if (fromType && toType) {
809  if (fromType.getShape() == toType.getShape() &&
810  fromType.getElementType() == toType.getElementType()) {
811  rewriter.modifyOpInPlace(copyOp, [&] {
812  copyOp.getTargetMutable().assign(castOp.getSource());
813  });
814  modified = true;
815  }
816  }
817  }
818 
819  return success(modified);
820  }
821 };
822 
823 /// Fold memref.copy(%x, %x).
824 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
825   using OpRewritePattern<CopyOp>::OpRewritePattern;
826 
827  LogicalResult matchAndRewrite(CopyOp copyOp,
828  PatternRewriter &rewriter) const override {
829  if (copyOp.getSource() != copyOp.getTarget())
830  return failure();
831 
832  rewriter.eraseOp(copyOp);
833  return success();
834  }
835 };
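// Illustrative sketches of the two patterns above (hand-written examples):
// ```mlir
// // FoldSelfCopy: a copy whose source and target are the same value is
// // erased.
// memref.copy %buf, %buf : memref<16xf32> to memref<16xf32>
//
// // FoldCopyOfCast: a cast that only changes the layout is skipped.
// %0 = memref.cast %src : memref<16xf32> to memref<16xf32, strided<[1], offset: ?>>
// memref.copy %0, %dst : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32>
// // becomes:
// memref.copy %src, %dst : memref<16xf32> to memref<16xf32>
// ```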
836 } // namespace
837 
838 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
839  MLIRContext *context) {
840  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
841 }
842 
843 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
844                            SmallVectorImpl<OpFoldResult> &results) {
845   /// copy(memrefcast) -> copy
846  bool folded = false;
847  Operation *op = *this;
848  for (OpOperand &operand : op->getOpOperands()) {
849  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
850  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
851  operand.set(castOp.getOperand());
852  folded = true;
853  }
854  }
855  return success(folded);
856 }
857 
858 //===----------------------------------------------------------------------===//
859 // DeallocOp
860 //===----------------------------------------------------------------------===//
861 
862 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
863                               SmallVectorImpl<OpFoldResult> &results) {
864   /// dealloc(memrefcast) -> dealloc
865  return foldMemRefCast(*this);
866 }
867 
868 //===----------------------------------------------------------------------===//
869 // DimOp
870 //===----------------------------------------------------------------------===//
871 
872 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
873  setNameFn(getResult(), "dim");
874 }
875 
876 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
877  int64_t index) {
878  auto loc = result.location;
879  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
880  build(builder, result, source, indexValue);
881 }
882 
883 std::optional<int64_t> DimOp::getConstantIndex() {
884  return getConstantIntValue(getIndex());
885 }
886 
887 Speculation::Speculatability DimOp::getSpeculatability() {
888  auto constantIndex = getConstantIndex();
889   if (!constantIndex)
890     return Speculation::NotSpeculatable;
891 
892   auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
893   if (!rankedSourceType)
894     return Speculation::NotSpeculatable;
895 
896   if (rankedSourceType.getRank() <= constantIndex)
897     return Speculation::NotSpeculatable;
898 
899   return Speculation::Speculatable;
900 }
901 
902 /// Return a map with key being elements in `vals` and data being number of
903 /// occurrences of it. Use std::map, since the `vals` here are strides and the
904 /// dynamic stride value is the same as the tombstone value for
905 /// `DenseMap<int64_t>`.
906 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
907  std::map<int64_t, unsigned> numOccurences;
908  for (auto val : vals)
909  numOccurences[val]++;
910  return numOccurences;
911 }
912 
913 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
914 /// to be a subset of `originalType` with some `1` entries erased, return the
915 /// set of indices that specifies which of the entries of `originalShape` are
916 /// dropped to obtain `reducedShape`.
917 /// This accounts for cases where there are multiple unit-dims, but only a
918 /// subset of those are dropped. For MemRefTypes these can be disambiguated
919 /// using the strides. If a dimension is dropped the stride must be dropped too.
920 static FailureOr<llvm::SmallBitVector>
921 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
922  ArrayRef<OpFoldResult> sizes) {
923  llvm::SmallBitVector unusedDims(originalType.getRank());
924  if (originalType.getRank() == reducedType.getRank())
925  return unusedDims;
926 
927  for (const auto &dim : llvm::enumerate(sizes))
928  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
929  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
930  unusedDims.set(dim.index());
931 
932  // Early exit for the case where the number of unused dims matches the number
933  // of ranks reduced.
934  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
935  originalType.getRank())
936  return unusedDims;
937 
938  SmallVector<int64_t> originalStrides, candidateStrides;
939  int64_t originalOffset, candidateOffset;
940  if (failed(
941  getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
942  failed(
943  getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
944  return failure();
945 
946  // For memrefs, a dimension is truly dropped if its corresponding stride is
947  // also dropped. This is particularly important when more than one of the dims
948   // is 1. Track the number of occurrences of the strides in the original type
949   // and the candidate type. For each unused dim that stride should not be
950   // present in the candidate type. Note that there could be multiple dimensions
951   // that have the same size. We don't need to exactly figure out which dim
952   // corresponds to which stride, we just need to verify that the number of
953   // repetitions of a stride in the original + number of unused dims with that
954   // stride == number of repetitions of a stride in the candidate.
955  std::map<int64_t, unsigned> currUnaccountedStrides =
956  getNumOccurences(originalStrides);
957  std::map<int64_t, unsigned> candidateStridesNumOccurences =
958  getNumOccurences(candidateStrides);
959  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
960  if (!unusedDims.test(dim))
961  continue;
962  int64_t originalStride = originalStrides[dim];
963  if (currUnaccountedStrides[originalStride] >
964  candidateStridesNumOccurences[originalStride]) {
965  // This dim can be treated as dropped.
966  currUnaccountedStrides[originalStride]--;
967  continue;
968  }
969  if (currUnaccountedStrides[originalStride] ==
970  candidateStridesNumOccurences[originalStride]) {
971  // The stride for this is not dropped. Keep as is.
972  unusedDims.reset(dim);
973  continue;
974  }
975  if (currUnaccountedStrides[originalStride] <
976  candidateStridesNumOccurences[originalStride]) {
977       // This should never happen. Can't have a stride in the reduced rank type
978       // that wasn't in the original one.
979  return failure();
980  }
981  }
982 
983  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
984  originalType.getRank())
985  return failure();
986  return unusedDims;
987 }
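// For illustration, a rank-reducing subview whose dropped unit dimension is
// what the mask above identifies (hand-written sketch):
// ```mlir
// %0 = memref.subview %m[0, 0] [1, 16] [1, 1]
//     : memref<8x16xf32> to memref<16xf32, strided<[1]>>
// // Source dim 0 (size 1, stride 16) is dropped; dim 1 is kept.
// ```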
988 
989 llvm::SmallBitVector SubViewOp::getDroppedDims() {
990  MemRefType sourceType = getSourceType();
991  MemRefType resultType = getType();
992   FailureOr<llvm::SmallBitVector> unusedDims =
993       computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
994  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
995  return *unusedDims;
996 }
997 
998 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
999  // All forms of folding require a known index.
1000  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1001  if (!index)
1002  return {};
1003 
1004  // Folding for unranked types (UnrankedMemRefType) is not supported.
1005  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1006  if (!memrefType)
1007  return {};
1008 
1009  // Out of bound indices produce undefined behavior but are still valid IR.
1010  // Don't choke on them.
1011  int64_t indexVal = index.getInt();
1012  if (indexVal < 0 || indexVal >= memrefType.getRank())
1013  return {};
1014 
1015  // Fold if the shape extent along the given index is known.
1016  if (!memrefType.isDynamicDim(index.getInt())) {
1017  Builder builder(getContext());
1018  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1019  }
1020 
1021  // The size at the given index is now known to be a dynamic size.
1022  unsigned unsignedIndex = index.getValue().getZExtValue();
1023 
1024  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1025  Operation *definingOp = getSource().getDefiningOp();
1026 
1027  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1028  return *(alloc.getDynamicSizes().begin() +
1029  memrefType.getDynamicDimIndex(unsignedIndex));
1030 
1031  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1032  return *(alloca.getDynamicSizes().begin() +
1033  memrefType.getDynamicDimIndex(unsignedIndex));
1034 
1035  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1036  return *(view.getDynamicSizes().begin() +
1037  memrefType.getDynamicDimIndex(unsignedIndex));
1038 
1039  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1040  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1041  unsigned resultIndex = 0;
1042  unsigned sourceRank = subview.getSourceType().getRank();
1043  unsigned sourceIndex = 0;
1044  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1045  if (unusedDims.test(i))
1046  continue;
1047  if (resultIndex == unsignedIndex) {
1048  sourceIndex = i;
1049  break;
1050  }
1051  resultIndex++;
1052  }
1053  assert(subview.isDynamicSize(sourceIndex) &&
1054  "expected dynamic subview size");
1055  return subview.getDynamicSize(sourceIndex);
1056  }
1057 
1058  if (auto sizeInterface =
1059  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1060  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1061  "Expected dynamic subview size");
1062  return sizeInterface.getDynamicSize(unsignedIndex);
1063  }
1064 
1065  // dim(memrefcast) -> dim
1066  if (succeeded(foldMemRefCast(*this)))
1067  return getResult();
1068 
1069  return {};
1070 }
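// Illustrative sketches of the folds above (hand-written examples):
// ```mlir
// // Static dimension: folds to a constant index.
// %c1 = arith.constant 1 : index
// %d0 = memref.dim %m, %c1 : memref<4x8xf32>   // folds to 8
//
// // Dynamic dimension of an alloc: folds to the corresponding size operand.
// %a  = memref.alloc(%n) : memref<?xf32>
// %c0 = arith.constant 0 : index
// %d1 = memref.dim %a, %c0 : memref<?xf32>     // folds to %n
// ```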
1071 
1072 namespace {
1073 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1074 /// operand.
1075 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1076   using OpRewritePattern<DimOp>::OpRewritePattern;
1077 
1078  LogicalResult matchAndRewrite(DimOp dim,
1079  PatternRewriter &rewriter) const override {
1080  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1081 
1082  if (!reshape)
1083  return failure();
1084 
1085  // Place the load directly after the reshape to ensure that the shape memref
1086  // was not mutated.
1087  rewriter.setInsertionPointAfter(reshape);
1088  Location loc = dim.getLoc();
1089  Value load =
1090  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1091  if (load.getType() != dim.getType())
1092  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1093  rewriter.replaceOp(dim, load);
1094  return success();
1095  }
1096 };
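// Illustrative sketch of DimOfMemRefReshape (hand-written example):
// ```mlir
// %r  = memref.reshape %src(%shape)
//     : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
// %c0 = arith.constant 0 : index
// %d  = memref.dim %r, %c0 : memref<?x?xf32>
// // is rewritten to a load from the shape operand:
// %d2 = memref.load %shape[%c0] : memref<2xindex>
// ```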
1097 
1098 } // namespace
1099 
1100 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1101  MLIRContext *context) {
1102  results.add<DimOfMemRefReshape>(context);
1103 }
1104 
1105 // ---------------------------------------------------------------------------
1106 // DmaStartOp
1107 // ---------------------------------------------------------------------------
1108 
1109 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1110  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1111  ValueRange destIndices, Value numElements,
1112  Value tagMemRef, ValueRange tagIndices, Value stride,
1113  Value elementsPerStride) {
1114  result.addOperands(srcMemRef);
1115  result.addOperands(srcIndices);
1116  result.addOperands(destMemRef);
1117  result.addOperands(destIndices);
1118  result.addOperands({numElements, tagMemRef});
1119  result.addOperands(tagIndices);
1120  if (stride)
1121  result.addOperands({stride, elementsPerStride});
1122 }
1123 
1124 void DmaStartOp::print(OpAsmPrinter &p) {
1125   p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1126  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1127  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1128  if (isStrided())
1129  p << ", " << getStride() << ", " << getNumElementsPerStride();
1130 
1131  p.printOptionalAttrDict((*this)->getAttrs());
1132  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1133  << ", " << getTagMemRef().getType();
1134 }
1135 
1136 // Parse DmaStartOp.
1137 // Ex:
1138 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1139 // %tag[%index], %stride, %num_elt_per_stride :
1140 // : memref<3076 x f32, 0>,
1141 // memref<1024 x f32, 2>,
1142 // memref<1 x i32>
1143 //
1144 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1145   OpAsmParser::UnresolvedOperand srcMemRefInfo;
1146   SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1147   OpAsmParser::UnresolvedOperand dstMemRefInfo;
1148   SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1149   OpAsmParser::UnresolvedOperand numElementsInfo;
1150   OpAsmParser::UnresolvedOperand tagMemrefInfo;
1151   SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1152   SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1153 
1154  SmallVector<Type, 3> types;
1155  auto indexType = parser.getBuilder().getIndexType();
1156 
1157  // Parse and resolve the following list of operands:
1158  // *) source memref followed by its indices (in square brackets).
1159  // *) destination memref followed by its indices (in square brackets).
1160  // *) dma size in KiB.
1161  if (parser.parseOperand(srcMemRefInfo) ||
1162  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1163  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1164  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1165  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1166  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1167  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1168  return failure();
1169 
1170  // Parse optional stride and elements per stride.
1171  if (parser.parseTrailingOperandList(strideInfo))
1172  return failure();
1173 
1174  bool isStrided = strideInfo.size() == 2;
1175  if (!strideInfo.empty() && !isStrided) {
1176  return parser.emitError(parser.getNameLoc(),
1177  "expected two stride related operands");
1178  }
1179 
1180  if (parser.parseColonTypeList(types))
1181  return failure();
1182  if (types.size() != 3)
1183  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1184 
1185  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1186  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1187  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1188  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1189  // size should be an index.
1190  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1191  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1192  // tag indices should be index.
1193  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1194  return failure();
1195 
1196  if (isStrided) {
1197  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1198  return failure();
1199  }
1200 
1201  return success();
1202 }
1203 
1204 LogicalResult DmaStartOp::verify() {
1205   unsigned numOperands = getNumOperands();
1206 
1207  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1208  // the number of elements.
1209  if (numOperands < 4)
1210  return emitOpError("expected at least 4 operands");
1211 
1212  // Check types of operands. The order of these calls is important: the later
1213  // calls rely on some type properties to compute the operand position.
1214  // 1. Source memref.
1215  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1216  return emitOpError("expected source to be of memref type");
1217  if (numOperands < getSrcMemRefRank() + 4)
1218  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1219  << " operands";
1220  if (!getSrcIndices().empty() &&
1221  !llvm::all_of(getSrcIndices().getTypes(),
1222  [](Type t) { return t.isIndex(); }))
1223  return emitOpError("expected source indices to be of index type");
1224 
1225  // 2. Destination memref.
1226  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1227  return emitOpError("expected destination to be of memref type");
1228  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1229  if (numOperands < numExpectedOperands)
1230  return emitOpError() << "expected at least " << numExpectedOperands
1231  << " operands";
1232  if (!getDstIndices().empty() &&
1233  !llvm::all_of(getDstIndices().getTypes(),
1234  [](Type t) { return t.isIndex(); }))
1235  return emitOpError("expected destination indices to be of index type");
1236 
1237  // 3. Number of elements.
1238  if (!getNumElements().getType().isIndex())
1239  return emitOpError("expected num elements to be of index type");
1240 
1241  // 4. Tag memref.
1242  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1243  return emitOpError("expected tag to be of memref type");
1244  numExpectedOperands += getTagMemRefRank();
1245  if (numOperands < numExpectedOperands)
1246  return emitOpError() << "expected at least " << numExpectedOperands
1247  << " operands";
1248  if (!getTagIndices().empty() &&
1249  !llvm::all_of(getTagIndices().getTypes(),
1250  [](Type t) { return t.isIndex(); }))
1251  return emitOpError("expected tag indices to be of index type");
1252 
1253  // Optional stride-related operands must be either both present or both
1254  // absent.
1255  if (numOperands != numExpectedOperands &&
1256  numOperands != numExpectedOperands + 2)
1257  return emitOpError("incorrect number of operands");
1258 
1259  // 5. Strides.
1260  if (isStrided()) {
1261  if (!getStride().getType().isIndex() ||
1262  !getNumElementsPerStride().getType().isIndex())
1263  return emitOpError(
1264  "expected stride and num elements per stride to be of type index");
1265  }
1266 
1267  return success();
1268 }
1269 
1270 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1271  SmallVectorImpl<OpFoldResult> &results) {
1272  /// dma_start(memrefcast) -> dma_start
1273  return foldMemRefCast(*this);
1274 }
1275 
1276 // ---------------------------------------------------------------------------
1277 // DmaWaitOp
1278 // ---------------------------------------------------------------------------
1279 
1280 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1281  SmallVectorImpl<OpFoldResult> &results) {
1282  /// dma_wait(memrefcast) -> dma_wait
1283  return foldMemRefCast(*this);
1284 }
1285 
1286 LogicalResult DmaWaitOp::verify() {
1287   // Check that the number of tag indices matches the tagMemRef rank.
1288  unsigned numTagIndices = getTagIndices().size();
1289  unsigned tagMemRefRank = getTagMemRefRank();
1290  if (numTagIndices != tagMemRefRank)
1291  return emitOpError() << "expected tagIndices to have the same number of "
1292  "elements as the tagMemRef rank, expected "
1293  << tagMemRefRank << ", but got " << numTagIndices;
1294  return success();
1295 }
1296 
1297 //===----------------------------------------------------------------------===//
1298 // ExtractAlignedPointerAsIndexOp
1299 //===----------------------------------------------------------------------===//
1300 
1301 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1302  function_ref<void(Value, StringRef)> setNameFn) {
1303  setNameFn(getResult(), "intptr");
1304 }
1305 
1306 //===----------------------------------------------------------------------===//
1307 // ExtractStridedMetadataOp
1308 //===----------------------------------------------------------------------===//
1309 
1310 /// The number and type of the results are inferred from the
1311 /// shape of the source.
1312 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1313  MLIRContext *context, std::optional<Location> location,
1314  ExtractStridedMetadataOp::Adaptor adaptor,
1315  SmallVectorImpl<Type> &inferredReturnTypes) {
1316  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1317  if (!sourceType)
1318  return failure();
1319 
1320  unsigned sourceRank = sourceType.getRank();
1321  IndexType indexType = IndexType::get(context);
1322  auto memrefType =
1323  MemRefType::get({}, sourceType.getElementType(),
1324  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1325  // Base.
1326  inferredReturnTypes.push_back(memrefType);
1327  // Offset.
1328  inferredReturnTypes.push_back(indexType);
1329  // Sizes and strides.
1330  for (unsigned i = 0; i < sourceRank * 2; ++i)
1331  inferredReturnTypes.push_back(indexType);
1332  return success();
1333 }
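// For illustration, the result layout for a rank-2 source (hand-written
// sketch): one base buffer, one offset, then rank sizes and rank strides.
// ```mlir
// %base, %offset, %sizes:2, %strides:2 =
//     memref.extract_strided_metadata %m
//     : memref<8x16xf32> -> memref<f32>, index, index, index, index, index
// ```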
1334 
1335 void ExtractStridedMetadataOp::getAsmResultNames(
1336  function_ref<void(Value, StringRef)> setNameFn) {
1337  setNameFn(getBaseBuffer(), "base_buffer");
1338  setNameFn(getOffset(), "offset");
1339  // For multi-result to work properly with pretty names and packed syntax `x:3`
1340  // we can only give a pretty name to the first value in the pack.
1341  if (!getSizes().empty()) {
1342  setNameFn(getSizes().front(), "sizes");
1343  setNameFn(getStrides().front(), "strides");
1344  }
1345 }
1346 
1347 /// Helper function to perform the replacement of all constant uses of `values`
1348 /// by a materialized constant extracted from `maybeConstants`.
1349 /// `values` and `maybeConstants` are expected to have the same size.
1350 template <typename Container>
1351 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1352  Container values,
1353  ArrayRef<OpFoldResult> maybeConstants) {
1354  assert(values.size() == maybeConstants.size() &&
1355  " expected values and maybeConstants of the same size");
1356  bool atLeastOneReplacement = false;
1357  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1358     // Don't materialize a constant if there are no uses: this would induce
1359  // infinite loops in the driver.
1360  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1361  continue;
1362  assert(maybeConstant.template is<Attribute>() &&
1363  "The constified value should be either unchanged (i.e., == result) "
1364  "or a constant");
1365  Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1366  loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
1367  .getInt());
1368  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1369  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1370  // yet.
1371  op->replaceUsesOfWith(result, constantVal);
1372  atLeastOneReplacement = true;
1373  }
1374  }
1375  return atLeastOneReplacement;
1376 }
1377 
1378 LogicalResult
1379 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1380  SmallVectorImpl<OpFoldResult> &results) {
1381  OpBuilder builder(*this);
1382 
1383  bool atLeastOneReplacement = replaceConstantUsesOf(
1384  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1385  getConstifiedMixedOffset());
1386  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1387  getConstifiedMixedSizes());
1388  atLeastOneReplacement |= replaceConstantUsesOf(
1389  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1390 
1391  return success(atLeastOneReplacement);
1392 }
1393 
1394 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1395  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1396  constifyIndexValues(values, getSource().getType(), getContext(),
1397  getConstantSizes, ShapedType::isDynamic);
1398  return values;
1399 }
1400 
1401 SmallVector<OpFoldResult>
1402 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1403  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1404  constifyIndexValues(values, getSource().getType(), getContext(),
1405  getConstantStrides, ShapedType::isDynamic);
1406  return values;
1407 }
1408 
1409 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1410  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1411  SmallVector<OpFoldResult> values(1, offsetOfr);
1412  constifyIndexValues(values, getSource().getType(), getContext(),
1413  getConstantOffset, ShapedType::isDynamic);
1414  return values[0];
1415 }
1416 
1417 //===----------------------------------------------------------------------===//
1418 // GenericAtomicRMWOp
1419 //===----------------------------------------------------------------------===//
1420 
1421 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1422  Value memref, ValueRange ivs) {
1423  OpBuilder::InsertionGuard g(builder);
1424  result.addOperands(memref);
1425  result.addOperands(ivs);
1426 
1427  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1428  Type elementType = memrefType.getElementType();
1429  result.addTypes(elementType);
1430 
1431  Region *bodyRegion = result.addRegion();
1432  builder.createBlock(bodyRegion);
1433  bodyRegion->addArgument(elementType, memref.getLoc());
1434  }
1435 }
1436 
1437 LogicalResult GenericAtomicRMWOp::verify() {
1438   auto &body = getRegion();
1439  if (body.getNumArguments() != 1)
1440  return emitOpError("expected single number of entry block arguments");
1441 
1442  if (getResult().getType() != body.getArgument(0).getType())
1443  return emitOpError("expected block argument of the same type result type");
1444 
1445  bool hasSideEffects =
1446  body.walk([&](Operation *nestedOp) {
1447  if (isMemoryEffectFree(nestedOp))
1448  return WalkResult::advance();
1449  nestedOp->emitError(
1450  "body of 'memref.generic_atomic_rmw' should contain "
1451  "only operations with no side effects");
1452  return WalkResult::interrupt();
1453  })
1454  .wasInterrupted();
1455  return hasSideEffects ? failure() : success();
1456 }
1457 
1458 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1459                                       OperationState &result) {
1460   OpAsmParser::UnresolvedOperand memref;
1461   Type memrefType;
1462   SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1463 
1464   Type indexType = parser.getBuilder().getIndexType();
1465   if (parser.parseOperand(memref) ||
1466       parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1467       parser.parseColonType(memrefType) ||
1468  parser.resolveOperand(memref, memrefType, result.operands) ||
1469  parser.resolveOperands(ivs, indexType, result.operands))
1470  return failure();
1471 
1472  Region *body = result.addRegion();
1473  if (parser.parseRegion(*body, {}) ||
1474  parser.parseOptionalAttrDict(result.attributes))
1475  return failure();
1476  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1477  return success();
1478 }
1479 
1480 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1481   p << ' ' << getMemref() << "[" << getIndices()
1482  << "] : " << getMemref().getType() << ' ';
1483  p.printRegion(getRegion());
1484  p.printOptionalAttrDict((*this)->getAttrs());
1485 }
1486 
1487 //===----------------------------------------------------------------------===//
1488 // AtomicYieldOp
1489 //===----------------------------------------------------------------------===//
1490 
1491 LogicalResult AtomicYieldOp::verify() {
1492   Type parentType = (*this)->getParentOp()->getResultTypes().front();
1493  Type resultType = getResult().getType();
1494  if (parentType != resultType)
1495  return emitOpError() << "types mismatch between yield op: " << resultType
1496  << " and its parent: " << parentType;
1497  return success();
1498 }
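// Illustrative sketch of a well-formed generic_atomic_rmw (hand-written
// example): the block takes one argument of the element type and the yielded
// value matches the op's result type.
// ```mlir
// %old = memref.generic_atomic_rmw %buf[%i] : memref<16xf32> {
// ^bb0(%current : f32):
//   %cst = arith.constant 1.0 : f32
//   %new = arith.addf %current, %cst : f32
//   memref.atomic_yield %new : f32
// }
// ```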
1499 
1500 //===----------------------------------------------------------------------===//
1501 // GlobalOp
1502 //===----------------------------------------------------------------------===//
1503 
1504 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1505                                                    TypeAttr type,
1506  Attribute initialValue) {
1507  p << type;
1508  if (!op.isExternal()) {
1509  p << " = ";
1510  if (op.isUninitialized())
1511  p << "uninitialized";
1512  else
1513  p.printAttributeWithoutType(initialValue);
1514  }
1515 }
1516 
1517 static ParseResult
1518 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1519                                        Attribute &initialValue) {
1520  Type type;
1521  if (parser.parseType(type))
1522  return failure();
1523 
1524  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1525  if (!memrefType || !memrefType.hasStaticShape())
1526  return parser.emitError(parser.getNameLoc())
1527  << "type should be static shaped memref, but got " << type;
1528  typeAttr = TypeAttr::get(type);
1529 
1530  if (parser.parseOptionalEqual())
1531  return success();
1532 
1533  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1534  initialValue = UnitAttr::get(parser.getContext());
1535  return success();
1536  }
1537 
1538  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1539  if (parser.parseAttribute(initialValue, tensorType))
1540  return failure();
1541  if (!llvm::isa<ElementsAttr>(initialValue))
1542  return parser.emitError(parser.getNameLoc())
1543  << "initial value should be a unit or elements attribute";
1544  return success();
1545 }
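// Sketch of the textual forms accepted by this parser (illustrative, adapted
// from the op documentation; the symbol names are arbitrary):
//
//   memref.global "private" @x : memref<2xf32> = dense<[0.0, 2.0]>   // initialized
//   memref.global @y : memref<4xi32> = uninitialized                 // uninitialized
//   memref.global "private" @z : memref<3xf16>                       // external declaration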
1546 
1547 LogicalResult GlobalOp::verify() {
1548  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1549  if (!memrefType || !memrefType.hasStaticShape())
1550  return emitOpError("type should be static shaped memref, but got ")
1551  << getType();
1552 
1553  // Verify that the initial value, if present, is either a unit attribute or
1554  // an elements attribute.
1555  if (getInitialValue().has_value()) {
1556  Attribute initValue = getInitialValue().value();
1557  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1558  return emitOpError("initial value should be a unit or elements "
1559  "attribute, but got ")
1560  << initValue;
1561 
1562  // Check that the type of the initial value is compatible with the type of
1563  // the global variable.
1564  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1565  Type initType = elementsAttr.getType();
1566  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1567  if (initType != tensorType)
1568  return emitOpError("initial value expected to be of type ")
1569  << tensorType << ", but was of type " << initType;
1570  }
1571  }
1572 
1573  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1574  uint64_t alignment = *alignAttr;
1575 
1576  if (!llvm::isPowerOf2_64(alignment))
1577  return emitError() << "alignment attribute value " << alignment
1578  << " is not a power of 2";
1579  }
1580 
1581  // TODO: verify visibility for declarations.
1582  return success();
1583 }
1584 
1585 ElementsAttr GlobalOp::getConstantInitValue() {
1586  auto initVal = getInitialValue();
1587  if (getConstant() && initVal.has_value())
1588  return llvm::cast<ElementsAttr>(initVal.value());
1589  return {};
1590 }
1591 
1592 //===----------------------------------------------------------------------===//
1593 // GetGlobalOp
1594 //===----------------------------------------------------------------------===//
1595 
1596 LogicalResult
1597 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1598  // Verify that the result type is the same as the type of the referenced
1599  // memref.global op.
1600  auto global =
1601  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1602  if (!global)
1603  return emitOpError("'")
1604  << getName() << "' does not reference a valid global memref";
1605 
1606  Type resultType = getResult().getType();
1607  if (global.getType() != resultType)
1608  return emitOpError("result type ")
1609  << resultType << " does not match type " << global.getType()
1610  << " of the global memref @" << getName();
1611  return success();
1612 }
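// For reference, a well-formed use resolves the symbol to a matching global
// (illustrative example; @x is a hypothetical symbol):
//
//   memref.global "private" @x : memref<2xf32> = uninitialized
//   ...
//   %0 = memref.get_global @x : memref<2xf32>
//
// Referencing a missing symbol, or using a result type other than
// memref<2xf32> here, is rejected by verifySymbolUses above.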
1613 
1614 //===----------------------------------------------------------------------===//
1615 // LoadOp
1616 //===----------------------------------------------------------------------===//
1617 
1618 LogicalResult LoadOp::verify() {
1619  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1620  return emitOpError("incorrect number of indices for load, expected ")
1621  << getMemRefType().getRank() << " but got " << getIndices().size();
1622  }
1623  return success();
1624 }
1625 
1626 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1627  /// load(memrefcast) -> load
1628  if (succeeded(foldMemRefCast(*this)))
1629  return getResult();
1630  return OpFoldResult();
1631 }
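// Example of the fold above (illustrative): a load through a ranked
// memref.cast is rewritten to load directly from the cast's source.
//
//   %0 = memref.cast %arg : memref<4xf32> to memref<?xf32>
//   %1 = memref.load %0[%i] : memref<?xf32>
// becomes
//   %1 = memref.load %arg[%i] : memref<4xf32>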
1632 
1633 //===----------------------------------------------------------------------===//
1634 // MemorySpaceCastOp
1635 //===----------------------------------------------------------------------===//
1636 
1637 void MemorySpaceCastOp::getAsmResultNames(
1638  function_ref<void(Value, StringRef)> setNameFn) {
1639  setNameFn(getResult(), "memspacecast");
1640 }
1641 
1642 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1643  if (inputs.size() != 1 || outputs.size() != 1)
1644  return false;
1645  Type a = inputs.front(), b = outputs.front();
1646  auto aT = llvm::dyn_cast<MemRefType>(a);
1647  auto bT = llvm::dyn_cast<MemRefType>(b);
1648 
1649  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1650  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1651 
1652  if (aT && bT) {
1653  if (aT.getElementType() != bT.getElementType())
1654  return false;
1655  if (aT.getLayout() != bT.getLayout())
1656  return false;
1657  if (aT.getShape() != bT.getShape())
1658  return false;
1659  return true;
1660  }
1661  if (uaT && ubT) {
1662  return uaT.getElementType() == ubT.getElementType();
1663  }
1664  return false;
1665 }
1666 
1667 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1668  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1669  // t2)
1670  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1671  getSourceMutable().assign(parentCast.getSource());
1672  return getResult();
1673  }
1674  return Value{};
1675 }
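// Example of the fold above (illustrative, with integer memory spaces):
//
//   %0 = memref.memory_space_cast %src : memref<4xf32, 1> to memref<4xf32, 2>
//   %1 = memref.memory_space_cast %0 : memref<4xf32, 2> to memref<4xf32, 3>
// folds to
//   %1 = memref.memory_space_cast %src : memref<4xf32, 1> to memref<4xf32, 3>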
1676 
1677 //===----------------------------------------------------------------------===//
1678 // PrefetchOp
1679 //===----------------------------------------------------------------------===//
1680 
1681 void PrefetchOp::print(OpAsmPrinter &p) {
1682  p << " " << getMemref() << '[';
1683  p.printOperands(getIndices());
1684  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1685  p << ", locality<" << getLocalityHint();
1686  p << ">, " << (getIsDataCache() ? "data" : "instr");
1687  p.printOptionalAttrDict(
1688  (*this)->getAttrs(),
1689  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1690  p << " : " << getMemRefType();
1691 }
1692 
1693 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1694  OpAsmParser::UnresolvedOperand memrefInfo;
1695  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1696  IntegerAttr localityHint;
1697  MemRefType type;
1698  StringRef readOrWrite, cacheType;
1699 
1700  auto indexTy = parser.getBuilder().getIndexType();
1701  auto i32Type = parser.getBuilder().getIntegerType(32);
1702  if (parser.parseOperand(memrefInfo) ||
1703  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1704  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1705  parser.parseComma() || parser.parseKeyword("locality") ||
1706  parser.parseLess() ||
1707  parser.parseAttribute(localityHint, i32Type, "localityHint",
1708  result.attributes) ||
1709  parser.parseGreater() || parser.parseComma() ||
1710  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1711  parser.resolveOperand(memrefInfo, type, result.operands) ||
1712  parser.resolveOperands(indexInfo, indexTy, result.operands))
1713  return failure();
1714 
1715  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1716  return parser.emitError(parser.getNameLoc(),
1717  "rw specifier has to be 'read' or 'write'");
1718  result.addAttribute(
1719  PrefetchOp::getIsWriteAttrStrName(),
1720  parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1721 
1722  if (!cacheType.equals("data") && !cacheType.equals("instr"))
1723  return parser.emitError(parser.getNameLoc(),
1724  "cache type has to be 'data' or 'instr'");
1725 
1726  result.addAttribute(
1727  PrefetchOp::getIsDataCacheAttrStrName(),
1728  parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1729 
1730  return success();
1731 }
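// The concrete syntax round-tripped by this parser and the printer above looks
// like the following (illustrative; %buf and %i are hypothetical values):
//
//   memref.prefetch %buf[%i], read, locality<3>, data : memref<8xf32>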
1732 
1733 LogicalResult PrefetchOp::verify() {
1734  if (getNumOperands() != 1 + getMemRefType().getRank())
1735  return emitOpError("too few indices");
1736 
1737  return success();
1738 }
1739 
1740 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1741  SmallVectorImpl<OpFoldResult> &results) {
1742  // prefetch(memrefcast) -> prefetch
1743  return foldMemRefCast(*this);
1744 }
1745 
1746 //===----------------------------------------------------------------------===//
1747 // RankOp
1748 //===----------------------------------------------------------------------===//
1749 
1750 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1751  // Constant fold rank when the rank of the operand is known.
1752  auto type = getOperand().getType();
1753  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1754  if (shapedType && shapedType.hasRank())
1755  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1756  return IntegerAttr();
1757 }
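// Example of the fold above (illustrative): the operand type is ranked, so the
// rank is known statically even though one dimension is dynamic.
//
//   %r = memref.rank %m : memref<4x?xf32>
// folds to the index constant 2.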
1758 
1759 //===----------------------------------------------------------------------===//
1760 // ReinterpretCastOp
1761 //===----------------------------------------------------------------------===//
1762 
1763 void ReinterpretCastOp::getAsmResultNames(
1764  function_ref<void(Value, StringRef)> setNameFn) {
1765  setNameFn(getResult(), "reinterpret_cast");
1766 }
1767 
1768 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1769 /// `staticSizes` and `staticStrides` are automatically filled with
1770 /// source-memref-rank sentinel values that encode dynamic entries.
1771 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1772  MemRefType resultType, Value source,
1773  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1774  ArrayRef<OpFoldResult> strides,
1775  ArrayRef<NamedAttribute> attrs) {
1776  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1777  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1778  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1779  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1780  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1781  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1782  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1783  b.getDenseI64ArrayAttr(staticSizes),
1784  b.getDenseI64ArrayAttr(staticStrides));
1785  result.addAttributes(attrs);
1786 }
1787 
1788 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1789  MemRefType resultType, Value source,
1790  int64_t offset, ArrayRef<int64_t> sizes,
1791  ArrayRef<int64_t> strides,
1792  ArrayRef<NamedAttribute> attrs) {
1793  SmallVector<OpFoldResult> sizeValues =
1794  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1795  return b.getI64IntegerAttr(v);
1796  }));
1797  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1798  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1799  return b.getI64IntegerAttr(v);
1800  }));
1801  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1802  strideValues, attrs);
1803 }
1804 
1805 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1806  MemRefType resultType, Value source, Value offset,
1807  ValueRange sizes, ValueRange strides,
1808  ArrayRef<NamedAttribute> attrs) {
1809  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1810  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1811  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1812  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1813  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1814 }
1815 
1816 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1817 // completed automatically, like we have for subview and extract_slice.
1818 LogicalResult ReinterpretCastOp::verify() {
1819  // The source and result memrefs should be in the same memory space.
1820  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1821  auto resultType = llvm::cast<MemRefType>(getType());
1822  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1823  return emitError("different memory spaces specified for source type ")
1824  << srcType << " and result memref type " << resultType;
1825  if (srcType.getElementType() != resultType.getElementType())
1826  return emitError("different element types specified for source type ")
1827  << srcType << " and result memref type " << resultType;
1828 
1829  // Match sizes in result memref type and in static_sizes attribute.
1830  for (auto [idx, resultSize, expectedSize] :
1831  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1832  if (!ShapedType::isDynamic(resultSize) &&
1833  !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1834  return emitError("expected result type with size = ")
1835  << expectedSize << " instead of " << resultSize
1836  << " in dim = " << idx;
1837  }
1838 
1839  // Match offset and strides in static_offset and static_strides attributes. If
1840  // result memref type has no affine map specified, this will assume an
1841  // identity layout.
1842  int64_t resultOffset;
1843  SmallVector<int64_t, 4> resultStrides;
1844  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1845  return emitError("expected result type to have strided layout but found ")
1846  << resultType;
1847 
1848  // Match offset in result memref type and in static_offsets attribute.
1849  int64_t expectedOffset = getStaticOffsets().front();
1850  if (!ShapedType::isDynamic(resultOffset) &&
1851  !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
1852  return emitError("expected result type with offset = ")
1853  << expectedOffset << " instead of " << resultOffset;
1854 
1855  // Match strides in result memref type and in static_strides attribute.
1856  for (auto [idx, resultStride, expectedStride] :
1857  llvm::enumerate(resultStrides, getStaticStrides())) {
1858  if (!ShapedType::isDynamic(resultStride) &&
1859  !ShapedType::isDynamic(expectedStride) &&
1860  resultStride != expectedStride)
1861  return emitError("expected result type with stride = ")
1862  << expectedStride << " instead of " << resultStride
1863  << " in dim = " << idx;
1864  }
1865 
1866  return success();
1867 }
1868 
1869 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1870  Value src = getSource();
1871  auto getPrevSrc = [&]() -> Value {
1872  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1873  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1874  return prev.getSource();
1875 
1876  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1877  if (auto prev = src.getDefiningOp<CastOp>())
1878  return prev.getSource();
1879 
1880  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1881  // are 0.
1882  if (auto prev = src.getDefiningOp<SubViewOp>())
1883  if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1884  return isConstantIntValue(val, 0);
1885  }))
1886  return prev.getSource();
1887 
1888  return nullptr;
1889  };
1890 
1891  if (auto prevSrc = getPrevSrc()) {
1892  getSourceMutable().assign(prevSrc);
1893  return getResult();
1894  }
1895 
1896  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1897  if (!ShapedType::isDynamicShape(getType().getShape()) &&
1898  src.getType() == getType() && getStaticOffsets().front() == 0) {
1899  return src;
1900  }
1901 
1902  return nullptr;
1903 }
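// Example of the source-forwarding fold above (illustrative): the inner
// reinterpret_cast is bypassed because only its source buffer is relevant.
//
//   %0 = memref.reinterpret_cast %src to offset: [0], sizes: [16], strides: [1]
//       : memref<16xf32> to memref<16xf32, strided<[1]>>
//   %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [4, 4], strides: [4, 1]
//       : memref<16xf32, strided<[1]>> to memref<4x4xf32, strided<[4, 1]>>
// becomes
//   %1 = memref.reinterpret_cast %src to offset: [0], sizes: [4, 4], strides: [4, 1]
//       : memref<16xf32> to memref<4x4xf32, strided<[4, 1]>>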
1904 
1905 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1906  SmallVector<OpFoldResult> values = getMixedSizes();
1907  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
1908  ShapedType::isDynamic);
1909  return values;
1910 }
1911 
1912 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1913  SmallVector<OpFoldResult> values = getMixedStrides();
1914  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
1915  ShapedType::isDynamic);
1916  return values;
1917 }
1918 
1919 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1920  SmallVector<OpFoldResult> values = getMixedOffsets();
1921  assert(values.size() == 1 &&
1922  "reinterpret_cast must have one and only one offset");
1923  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
1924  ShapedType::isDynamic);
1925  return values[0];
1926 }
1927 
1928 namespace {
1929 /// Replace the sequence:
1930 /// ```
1931 /// base, offset, sizes, strides = extract_strided_metadata src
1932 /// dst = reinterpret_cast base to offset, sizes, strides
1933 /// ```
1934 /// With
1935 ///
1936 /// ```
1937 /// dst = memref.cast src
1938 /// ```
1939 ///
1940 /// Note: The cast operation is only inserted when the types of dst and src
1941 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
1942 ///
1943 /// This pattern also matches when the offset, sizes, and strides don't come
1944 /// directly from the `extract_strided_metadata`'s results but it can be
1945 /// statically proven that they would hold the same values.
1946 ///
1947 /// For instance, the following sequence would be replaced:
1948 /// ```
1949 /// base, offset, sizes, strides =
1950 /// extract_strided_metadata memref : memref<3x4xty>
1951 /// dst = reinterpret_cast base to 0, [3, 4], strides
1952 /// ```
1953 /// Because we know (thanks to the type of the input memref) that the variables
1954 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
1955 ///
1956 /// Similarly, the following sequence would be replaced:
1957 /// ```
1958 /// c0 = arith.constant 0
1959 /// c4 = arith.constant 4
1960 /// base, offset, sizes, strides =
1961 /// extract_strided_metadata memref : memref<3x4xty>
1962 /// dst = reinterpret_cast base to c0, [3, c4], strides
1963 /// ```
1964 /// Because we know that `offset` and `c0` will hold 0
1965 /// and `c4` will hold 4.
1966 struct ReinterpretCastOpExtractStridedMetadataFolder
1967  : public OpRewritePattern<ReinterpretCastOp> {
1968 public:
1969  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
1970 
1971  LogicalResult matchAndRewrite(ReinterpretCastOp op,
1972  PatternRewriter &rewriter) const override {
1973  auto extractStridedMetadata =
1974  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
1975  if (!extractStridedMetadata)
1976  return failure();
1977  // Check if the reinterpret cast reconstructs a memref with the exact same
1978  // properties as the extract strided metadata.
1979 
1980  // First, check that the strides are the same.
1981  SmallVector<OpFoldResult> extractStridesOfr =
1982  extractStridedMetadata.getConstifiedMixedStrides();
1983  SmallVector<OpFoldResult> reinterpretStridesOfr =
1984  op.getConstifiedMixedStrides();
1985  if (extractStridesOfr.size() != reinterpretStridesOfr.size())
1986  return failure();
1987 
1988  unsigned rank = op.getType().getRank();
1989  for (unsigned i = 0; i < rank; ++i) {
1990  if (extractStridesOfr[i] != reinterpretStridesOfr[i])
1991  return failure();
1992  }
1993 
1994  // Second, check the sizes.
1995  assert(extractStridedMetadata.getSizes().size() ==
1996  op.getMixedSizes().size() &&
1997  "Strides and sizes rank must match");
1998  SmallVector<OpFoldResult> extractSizesOfr =
1999  extractStridedMetadata.getConstifiedMixedSizes();
2000  SmallVector<OpFoldResult> reinterpretSizesOfr =
2001  op.getConstifiedMixedSizes();
2002  for (unsigned i = 0; i < rank; ++i) {
2003  if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2004  return failure();
2005  }
2006  // Finally, check the offset.
2007  assert(op.getMixedOffsets().size() == 1 &&
2008  "reinterpret_cast with more than one offset should have been "
2009  "rejected by the verifier");
2010  OpFoldResult extractOffsetOfr =
2011  extractStridedMetadata.getConstifiedMixedOffset();
2012  OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2013  if (extractOffsetOfr != reinterpretOffsetOfr)
2014  return failure();
2015 
2016  // At this point, we know that the back and forth between extract strided
2017  // metadata and reinterpret cast is a noop. However, the final type of the
2018  // reinterpret cast may not be exactly the same as the original memref.
2019  // E.g., it could be changing a dimension from static to dynamic. Check that
2020  // here and add a cast if necessary.
2021  Type srcTy = extractStridedMetadata.getSource().getType();
2022  if (srcTy == op.getResult().getType())
2023  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2024  else
2025  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2026  extractStridedMetadata.getSource());
2027 
2028  return success();
2029  }
2030 };
2031 } // namespace
2032 
2033 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2034  MLIRContext *context) {
2035  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2036 }
2037 
2038 //===----------------------------------------------------------------------===//
2039 // Reassociative reshape ops
2040 //===----------------------------------------------------------------------===//
2041 
2042 void CollapseShapeOp::getAsmResultNames(
2043  function_ref<void(Value, StringRef)> setNameFn) {
2044  setNameFn(getResult(), "collapse_shape");
2045 }
2046 
2047 void ExpandShapeOp::getAsmResultNames(
2048  function_ref<void(Value, StringRef)> setNameFn) {
2049  setNameFn(getResult(), "expand_shape");
2050 }
2051 
2052 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2053 /// result and operand. Layout maps are verified separately.
2054 ///
2055 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2056 /// allowed in a reassociation group.
2057 static LogicalResult
2058 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2059  ArrayRef<int64_t> expandedShape,
2060  ArrayRef<ReassociationIndices> reassociation,
2061  bool allowMultipleDynamicDimsPerGroup) {
2062  // There must be one reassociation group per collapsed dimension.
2063  if (collapsedShape.size() != reassociation.size())
2064  return op->emitOpError("invalid number of reassociation groups: found ")
2065  << reassociation.size() << ", expected " << collapsedShape.size();
2066 
2067  // The next expected expanded dimension index (while iterating over
2068  // reassociation indices).
2069  int64_t nextDim = 0;
2070  for (const auto &it : llvm::enumerate(reassociation)) {
2071  ReassociationIndices group = it.value();
2072  int64_t collapsedDim = it.index();
2073 
2074  bool foundDynamic = false;
2075  for (int64_t expandedDim : group) {
2076  if (expandedDim != nextDim++)
2077  return op->emitOpError("reassociation indices must be contiguous");
2078 
2079  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2080  return op->emitOpError("reassociation index ")
2081  << expandedDim << " is out of bounds";
2082 
2083  // Check if there are multiple dynamic dims in a reassociation group.
2084  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2085  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2086  return op->emitOpError(
2087  "at most one dimension in a reassociation group may be dynamic");
2088  foundDynamic = true;
2089  }
2090  }
2091 
2092  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2093  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2094  return op->emitOpError("collapsed dim (")
2095  << collapsedDim
2096  << ") must be dynamic if and only if reassociation group is "
2097  "dynamic";
2098 
2099  // If all dims in the reassociation group are static, the size of the
2100  // collapsed dim can be verified.
2101  if (!foundDynamic) {
2102  int64_t groupSize = 1;
2103  for (int64_t expandedDim : group)
2104  groupSize *= expandedShape[expandedDim];
2105  if (groupSize != collapsedShape[collapsedDim])
2106  return op->emitOpError("collapsed dim size (")
2107  << collapsedShape[collapsedDim]
2108  << ") must equal reassociation group size (" << groupSize << ")";
2109  }
2110  }
2111 
2112  if (collapsedShape.empty()) {
2113  // Rank 0: All expanded dimensions must be 1.
2114  for (int64_t d : expandedShape)
2115  if (d != 1)
2116  return op->emitOpError(
2117  "rank 0 memrefs can only be extended/collapsed with/from ones");
2118  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2119  // Rank >= 1: Number of dimensions among all reassociation groups must match
2120  // the result memref rank.
2121  return op->emitOpError("expanded rank (")
2122  << expandedShape.size()
2123  << ") inconsistent with number of reassociation indices (" << nextDim
2124  << ")";
2125  }
2126 
2127  return success();
2128 }
2129 
2130 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2131  return getSymbolLessAffineMaps(getReassociationExprs());
2132 }
2133 
2134 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2135  return convertReassociationIndicesToExprs(getContext(),
2136  getReassociationIndices());
2137 }
2138 
2139 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2140  return getSymbolLessAffineMaps(getReassociationExprs());
2141 }
2142 
2143 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2144  return convertReassociationIndicesToExprs(getContext(),
2145  getReassociationIndices());
2146 }
2147 
2148 /// Compute the layout map after expanding a given source MemRef type with the
2149 /// specified reassociation indices.
2150 static FailureOr<StridedLayoutAttr>
2151 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2152  ArrayRef<ReassociationIndices> reassociation) {
2153  int64_t srcOffset;
2154  SmallVector<int64_t> srcStrides;
2155  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2156  return failure();
2157  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2158 
2159  // 1-1 mapping between srcStrides and reassociation packs.
2160  // Each srcStride starts with the given value and gets expanded according to
2161  // the proper entries in resultShape.
2162  // Example:
2163  // srcStrides = [10000, 1 , 100 ],
2164  // reassociations = [ [0], [1], [2, 3, 4]],
2165  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2166  // -> For the purpose of stride calculation, the useful sizes are:
2167  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2168  // resultStrides = [10000, 1, 600, 200, 100]
2169  // Note that a stride does not get expanded along the first entry of each
2170  // shape pack.
2171  SmallVector<int64_t> reverseResultStrides;
2172  reverseResultStrides.reserve(resultShape.size());
2173  unsigned shapeIndex = resultShape.size() - 1;
2174  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2175  ReassociationIndices reassoc = std::get<0>(it);
2176  int64_t currentStrideToExpand = std::get<1>(it);
2177  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2178  reverseResultStrides.push_back(currentStrideToExpand);
2179  currentStrideToExpand =
2180  (SaturatedInteger::wrap(currentStrideToExpand) *
2181  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2182  .asInteger();
2183  }
2184  }
2185  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2186  resultStrides.resize(resultShape.size(), 1);
2187  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2188 }
2189 
2190 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2191  MemRefType srcType, ArrayRef<int64_t> resultShape,
2192  ArrayRef<ReassociationIndices> reassociation) {
2193  if (srcType.getLayout().isIdentity()) {
2194  // If the source is contiguous (i.e., no layout map specified), so is the
2195  // result.
2196  MemRefLayoutAttrInterface layout;
2197  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2198  srcType.getMemorySpace());
2199  }
2200 
2201  // Source may not be contiguous. Compute the layout map.
2202  FailureOr<StridedLayoutAttr> computedLayout =
2203  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2204  if (failed(computedLayout))
2205  return failure();
2206  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2207  srcType.getMemorySpace());
2208 }
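// Worked example for computeExpandedType (illustrative): expanding
// memref<3x4xf32, strided<[8, 1]>> with reassociation [[0, 1], [2]] to result
// shape [3, 1, 4] expands each source stride along its group, yielding
// memref<3x1x4xf32, strided<[8, 8, 1]>>. For an identity-layout source such as
// memref<3x4xf32>, the first branch above simply returns the identity-layout
// memref<3x1x4xf32>.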
2209 
2210 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2211  ArrayRef<int64_t> resultShape, Value src,
2212  ArrayRef<ReassociationIndices> reassociation) {
2213  // Only ranked memref source values are supported.
2214  auto srcType = llvm::cast<MemRefType>(src.getType());
2215  FailureOr<MemRefType> resultType =
2216  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2217  // Failure of this assertion usually indicates a problem with the source
2218  // type, e.g., could not get strides/offset.
2219  assert(succeeded(resultType) && "could not compute layout");
2220  build(builder, result, *resultType, src, reassociation);
2221 }
2222 
2223 LogicalResult ExpandShapeOp::verify() {
2224  MemRefType srcType = getSrcType();
2225  MemRefType resultType = getResultType();
2226 
2227  if (srcType.getRank() >= resultType.getRank())
2228  return emitOpError("expected rank expansion, but found source rank ")
2229  << srcType.getRank() << " >= result rank " << resultType.getRank();
2230 
2231  // Verify result shape.
2232  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2233  resultType.getShape(),
2234  getReassociationIndices(),
2235  /*allowMultipleDynamicDimsPerGroup=*/false)))
2236  return failure();
2237 
2238  // Compute expected result type (including layout map).
2239  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2240  srcType, resultType.getShape(), getReassociationIndices());
2241  if (failed(expectedResultType))
2242  return emitOpError("invalid source layout map");
2243 
2244  // Check actual result type.
2245  if (*expectedResultType != resultType)
2246  return emitOpError("expected expanded type to be ")
2247  << *expectedResultType << " but found " << resultType;
2248 
2249  return success();
2250 }
2251 
2252 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2253  MLIRContext *context) {
2254  results.add<
2255  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2256  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2257 }
2258 
2259 /// Compute the layout map after collapsing a given source MemRef type with the
2260 /// specified reassociation indices.
2261 ///
2262 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2263 /// not possible to check this by inspecting a MemRefType in the general case.
2264 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2265 /// be valid (and thus accepted by this function) unless `strict = true`.
2267 computeCollapsedLayoutMap(MemRefType srcType,
2268  ArrayRef<ReassociationIndices> reassociation,
2269  bool strict = false) {
2270  int64_t srcOffset;
2271  SmallVector<int64_t> srcStrides;
2272  auto srcShape = srcType.getShape();
2273  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2274  return failure();
2275 
2276  // The result stride of a reassociation group is the stride of the last entry
2277  // of the reassociation. (TODO: Should be the minimum stride in the
2278  // reassociation because strides are not necessarily sorted. E.g., when using
2279  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2280  // strides are meaningless and could have any arbitrary value.
2281  SmallVector<int64_t> resultStrides;
2282  resultStrides.reserve(reassociation.size());
2283  for (const ReassociationIndices &reassoc : reassociation) {
2284  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2285  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2286  ref = ref.drop_back();
2287  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2288  resultStrides.push_back(srcStrides[ref.back()]);
2289  } else {
2290  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2291  // the corresponding stride may have to be skipped. (See above comment.)
2292  // Therefore, the result stride cannot be statically determined and must
2293  // be dynamic.
2294  resultStrides.push_back(ShapedType::kDynamic);
2295  }
2296  }
2297 
2298  // Validate that each reassociation group is contiguous.
2299  unsigned resultStrideIndex = resultStrides.size() - 1;
2300  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2301  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2302  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2303  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2304  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2305 
2306  // Both source and result stride must have the same static value. In that
2307  // case, we can be sure, that the dimensions are collapsible (because they
2308  // are contiguous).
2309  // If `strict = false` (default during op verification), we accept cases
2310  // where one or both strides are dynamic. This is best effort: We reject
2311  // ops where obviously non-contiguous dims are collapsed, but accept ops
2312  // where we cannot be sure statically. Such ops may fail at runtime. See
2313  // the op documentation for details.
2314  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2315  if (strict && (stride.saturated || srcStride.saturated))
2316  return failure();
2317 
2318  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2319  return failure();
2320  }
2321  }
2322  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2323 }
2324 
2325 bool CollapseShapeOp::isGuaranteedCollapsible(
2326  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2327  // MemRefs with identity layout are always collapsible.
2328  if (srcType.getLayout().isIdentity())
2329  return true;
2330 
2331  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2332  /*strict=*/true));
2333 }
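// For instance (illustrative): collapsing memref<2x4xf32, strided<[4, 1]>> with
// reassociation [[0, 1]] is guaranteed collapsible, because the outer stride 4
// equals 4 (inner size) * 1 (inner stride), i.e. the two dims are contiguous.
// With strides [8, 1] the strict check above fails, since the group would span
// a non-contiguous region.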
2334 
2335 MemRefType CollapseShapeOp::computeCollapsedType(
2336  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2337  SmallVector<int64_t> resultShape;
2338  resultShape.reserve(reassociation.size());
2339  for (const ReassociationIndices &group : reassociation) {
2340  auto groupSize = SaturatedInteger::wrap(1);
2341  for (int64_t srcDim : group)
2342  groupSize =
2343  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2344  resultShape.push_back(groupSize.asInteger());
2345  }
2346 
2347  if (srcType.getLayout().isIdentity()) {
2348  // If the source is contiguous (i.e., no layout map specified), so is the
2349  // result.
2350  MemRefLayoutAttrInterface layout;
2351  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2352  srcType.getMemorySpace());
2353  }
2354 
2355  // Source may not be fully contiguous. Compute the layout map.
2356  // Note: Dimensions that are collapsed into a single dim are assumed to be
2357  // contiguous.
2358  FailureOr<StridedLayoutAttr> computedLayout =
2359  computeCollapsedLayoutMap(srcType, reassociation);
2360  assert(succeeded(computedLayout) &&
2361  "invalid source layout map or collapsing non-contiguous dims");
2362  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2363  srcType.getMemorySpace());
2364 }
2365 
2366 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2367  ArrayRef<ReassociationIndices> reassociation,
2368  ArrayRef<NamedAttribute> attrs) {
2369  auto srcType = llvm::cast<MemRefType>(src.getType());
2370  MemRefType resultType =
2371  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2372  build(b, result, resultType, src, attrs);
2373  result.addAttribute(getReassociationAttrStrName(),
2374  getReassociationIndicesAttribute(b, reassociation));
2375 }
2376 
2377 LogicalResult CollapseShapeOp::verify() {
2378  MemRefType srcType = getSrcType();
2379  MemRefType resultType = getResultType();
2380 
2381  if (srcType.getRank() <= resultType.getRank())
2382  return emitOpError("expected rank reduction, but found source rank ")
2383  << srcType.getRank() << " <= result rank " << resultType.getRank();
2384 
2385  // Verify result shape.
2386  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2387  srcType.getShape(), getReassociationIndices(),
2388  /*allowMultipleDynamicDimsPerGroup=*/true)))
2389  return failure();
2390 
2391  // Compute expected result type (including layout map).
2392  MemRefType expectedResultType;
2393  if (srcType.getLayout().isIdentity()) {
2394  // If the source is contiguous (i.e., no layout map specified), so is the
2395  // result.
2396  MemRefLayoutAttrInterface layout;
2397  expectedResultType =
2398  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2399  srcType.getMemorySpace());
2400  } else {
2401  // Source may not be fully contiguous. Compute the layout map.
2402  // Note: Dimensions that are collapsed into a single dim are assumed to be
2403  // contiguous.
2404  FailureOr<StridedLayoutAttr> computedLayout =
2405  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2406  if (failed(computedLayout))
2407  return emitOpError(
2408  "invalid source layout map or collapsing non-contiguous dims");
2409  expectedResultType =
2410  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2411  *computedLayout, srcType.getMemorySpace());
2412  }
2413 
2414  if (expectedResultType != resultType)
2415  return emitOpError("expected collapsed type to be ")
2416  << expectedResultType << " but found " << resultType;
2417 
2418  return success();
2419 }
2420 
2421 struct CollapseShapeOpMemRefCastFolder
2422  : public OpRewritePattern<CollapseShapeOp> {
2423 public:
2424  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2425 
2426  LogicalResult matchAndRewrite(CollapseShapeOp op,
2427  PatternRewriter &rewriter) const override {
2428  auto cast = op.getOperand().getDefiningOp<CastOp>();
2429  if (!cast)
2430  return failure();
2431 
2432  if (!CastOp::canFoldIntoConsumerOp(cast))
2433  return failure();
2434 
2435  Type newResultType = CollapseShapeOp::computeCollapsedType(
2436  llvm::cast<MemRefType>(cast.getOperand().getType()),
2437  op.getReassociationIndices());
2438 
2439  if (newResultType == op.getResultType()) {
2440  rewriter.modifyOpInPlace(
2441  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2442  } else {
2443  Value newOp = rewriter.create<CollapseShapeOp>(
2444  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2445  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2446  }
2447  return success();
2448  }
2449 };
2450 
2451 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2452  MLIRContext *context) {
2453  results.add<
2454  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2455  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>, CollapseShapeOpMemRefCastFolder>(context);
2456 }
2457 
2458 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2459  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2460  adaptor.getOperands());
2461 }
2462 
2463 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2464  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2465  adaptor.getOperands());
2466 }
2467 
2468 //===----------------------------------------------------------------------===//
2469 // ReshapeOp
2470 //===----------------------------------------------------------------------===//
2471 
2472 void ReshapeOp::getAsmResultNames(
2473  function_ref<void(Value, StringRef)> setNameFn) {
2474  setNameFn(getResult(), "reshape");
2475 }
2476 
2477 LogicalResult ReshapeOp::verify() {
2478  Type operandType = getSource().getType();
2479  Type resultType = getResult().getType();
2480 
2481  Type operandElementType =
2482  llvm::cast<ShapedType>(operandType).getElementType();
2483  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2484  if (operandElementType != resultElementType)
2485  return emitOpError("element types of source and destination memref "
2486  "types should be the same");
2487 
2488  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2489  if (!operandMemRefType.getLayout().isIdentity())
2490  return emitOpError("source memref type should have identity affine map");
2491 
2492  int64_t shapeSize =
2493  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2494  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2495  if (resultMemRefType) {
2496  if (!resultMemRefType.getLayout().isIdentity())
2497  return emitOpError("result memref type should have identity affine map");
2498  if (shapeSize == ShapedType::kDynamic)
2499  return emitOpError("cannot use shape operand with dynamic length to "
2500  "reshape to statically-ranked memref type");
2501  if (shapeSize != resultMemRefType.getRank())
2502  return emitOpError(
2503  "length of shape operand differs from the result's memref rank");
2504  }
2505  return success();
2506 }
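// Example of a well-formed op under the checks above (illustrative): both
// memrefs have identity layout and the shape operand has a static length equal
// to the result rank.
//
//   %dst = memref.reshape %src(%shape)
//       : (memref<4x4xf32>, memref<2xindex>) -> memref<2x8xf32>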
2507 
2508 //===----------------------------------------------------------------------===//
2509 // StoreOp
2510 //===----------------------------------------------------------------------===//
2511 
2512 LogicalResult StoreOp::verify() {
2513  if (getNumOperands() != 2 + getMemRefType().getRank())
2514  return emitOpError("store index operand count not equal to memref rank");
2515 
2516  return success();
2517 }
2518 
2519 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2520  SmallVectorImpl<OpFoldResult> &results) {
2521  /// store(memrefcast) -> store
2522  return foldMemRefCast(*this, getValueToStore());
2523 }
2524 
2525 //===----------------------------------------------------------------------===//
2526 // SubViewOp
2527 //===----------------------------------------------------------------------===//
2528 
2529 void SubViewOp::getAsmResultNames(
2530  function_ref<void(Value, StringRef)> setNameFn) {
2531  setNameFn(getResult(), "subview");
2532 }
2533 
2534 /// A subview result type can be fully inferred from the source type and the
2535 /// static representation of offsets, sizes and strides. Special sentinels
2536 /// encode the dynamic case.
2537 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2538  ArrayRef<int64_t> staticOffsets,
2539  ArrayRef<int64_t> staticSizes,
2540  ArrayRef<int64_t> staticStrides) {
2541  unsigned rank = sourceMemRefType.getRank();
2542  (void)rank;
2543  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2544  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2545  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2546 
2547  // Extract source offset and strides.
2548  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);
2549 
2550  // Compute target offset whose value is:
2551  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2552  int64_t targetOffset = sourceOffset;
2553  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2554  auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
2555  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2556  SaturatedInteger::wrap(staticOffset) *
2557  SaturatedInteger::wrap(targetStride))
2558  .asInteger();
2559  }
2560 
2561  // Compute target stride whose value is:
2562  // `sourceStrides_i * staticStrides_i`.
2563  SmallVector<int64_t, 4> targetStrides;
2564  targetStrides.reserve(staticOffsets.size());
2565  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2566  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2567  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2568  SaturatedInteger::wrap(staticStride))
2569  .asInteger());
2570  }
2571 
2572  // The type is now known.
2573  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2574  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2575  targetOffset, targetStrides),
2576  sourceMemRefType.getMemorySpace());
2577 }
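// Worked example (illustrative): for a source memref<8x16xf32> (strides
// [16, 1], offset 0) with static offsets [2, 4], sizes [4, 8] and strides
// [1, 2], the loops above compute
//   targetOffset  = 0 + 2 * 16 + 4 * 1 = 36
//   targetStrides = [16 * 1, 1 * 2]    = [16, 2]
// so the inferred type is memref<4x8xf32, strided<[16, 2], offset: 36>>.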
2578 
2579 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2580  ArrayRef<OpFoldResult> offsets,
2581  ArrayRef<OpFoldResult> sizes,
2582  ArrayRef<OpFoldResult> strides) {
2583  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2584  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2585  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2586  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2587  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2588  if (!hasValidSizesOffsets(staticOffsets))
2589  return {};
2590  if (!hasValidSizesOffsets(staticSizes))
2591  return {};
2592  if (!hasValidStrides(staticStrides))
2593  return {};
2594  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2595  staticSizes, staticStrides);
2596 }
2597 
2598 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2599  MemRefType sourceRankedTensorType,
2600  ArrayRef<int64_t> offsets,
2601  ArrayRef<int64_t> sizes,
2602  ArrayRef<int64_t> strides) {
2603  auto inferredType = llvm::cast<MemRefType>(
2604  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2605  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2606  "expected result shape rank to not exceed the inferred rank");
2607  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2608  return inferredType;
2609 
2610  // Compute which dimensions are dropped.
2611  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2612  computeRankReductionMask(inferredType.getShape(), resultShape);
2613  assert(dimsToProject.has_value() && "invalid rank reduction");
2614 
2615  // Compute the layout and result type.
2616  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2617  SmallVector<int64_t> rankReducedStrides;
2618  rankReducedStrides.reserve(resultShape.size());
2619  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2620  if (!dimsToProject->contains(idx))
2621  rankReducedStrides.push_back(value);
2622  }
2623  return MemRefType::get(resultShape, inferredType.getElementType(),
2624  StridedLayoutAttr::get(inferredLayout.getContext(),
2625  inferredLayout.getOffset(),
2626  rankReducedStrides),
2627  inferredType.getMemorySpace());
2628 }
2629 
2630 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2631  MemRefType sourceRankedTensorType,
2632  ArrayRef<OpFoldResult> offsets,
2633  ArrayRef<OpFoldResult> sizes,
2634  ArrayRef<OpFoldResult> strides) {
2635  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2636  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2637  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2638  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2639  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2640  return SubViewOp::inferRankReducedResultType(
2641  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2642  staticStrides);
2643 }
2644 
2645 // Build a SubViewOp with mixed static and dynamic entries and custom result
2646 // type. If the type passed is nullptr, it is inferred.
2647 void SubViewOp::build(OpBuilder &b, OperationState &result,
2648  MemRefType resultType, Value source,
2649  ArrayRef<OpFoldResult> offsets,
2650  ArrayRef<OpFoldResult> sizes,
2651  ArrayRef<OpFoldResult> strides,
2652  ArrayRef<NamedAttribute> attrs) {
2653  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2654  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2655  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2656  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2657  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2658  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2659  // Structuring implementation this way avoids duplication between builders.
2660  if (!resultType) {
2661  resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2662  sourceMemRefType, staticOffsets, staticSizes, staticStrides));
2663  }
2664  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2665  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2666  b.getDenseI64ArrayAttr(staticSizes),
2667  b.getDenseI64ArrayAttr(staticStrides));
2668  result.addAttributes(attrs);
2669 }
2670 
2671 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2672 // type.
2673 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2674  ArrayRef<OpFoldResult> offsets,
2675  ArrayRef<OpFoldResult> sizes,
2676  ArrayRef<OpFoldResult> strides,
2677  ArrayRef<NamedAttribute> attrs) {
2678  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2679 }
2680 
2681 // Build a SubViewOp with static entries and inferred result type.
2682 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2683  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2684  ArrayRef<int64_t> strides,
2685  ArrayRef<NamedAttribute> attrs) {
2686  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2687  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2688  return b.getI64IntegerAttr(v);
2689  }));
2690  SmallVector<OpFoldResult> sizeValues =
2691  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2692  return b.getI64IntegerAttr(v);
2693  }));
2694  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2695  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2696  return b.getI64IntegerAttr(v);
2697  }));
2698  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2699 }
2700 
2701 // Build a SubViewOp with dynamic entries and custom result type. If the
2702 // type passed is nullptr, it is inferred.
2703 void SubViewOp::build(OpBuilder &b, OperationState &result,
2704  MemRefType resultType, Value source,
2705  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2706  ArrayRef<int64_t> strides,
2707  ArrayRef<NamedAttribute> attrs) {
2708  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2709  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2710  return b.getI64IntegerAttr(v);
2711  }));
2712  SmallVector<OpFoldResult> sizeValues =
2713  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2714  return b.getI64IntegerAttr(v);
2715  }));
2716  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2717  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2718  return b.getI64IntegerAttr(v);
2719  }));
2720  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2721  attrs);
2722 }
2723 
2724 // Build a SubViewOp with dynamic entries and custom result type. If the type
2725 // passed is nullptr, it is inferred.
2726 void SubViewOp::build(OpBuilder &b, OperationState &result,
2727  MemRefType resultType, Value source, ValueRange offsets,
2728  ValueRange sizes, ValueRange strides,
2729  ArrayRef<NamedAttribute> attrs) {
2730  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2731  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2732  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2733  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2734  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2735  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2736  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2737 }
2738 
2739 // Build a SubViewOp with dynamic entries and inferred result type.
2740 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2741  ValueRange offsets, ValueRange sizes, ValueRange strides,
2742  ArrayRef<NamedAttribute> attrs) {
2743  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2744 }
2745 
2746 /// For ViewLikeOpInterface.
2747 Value SubViewOp::getViewSource() { return getSource(); }
2748 
2749 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2750 /// static value).
2751 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2752  int64_t t1Offset, t2Offset;
2753  SmallVector<int64_t> t1Strides, t2Strides;
2754  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2755  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2756  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2757 }
2758 
2759 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2760 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2761 /// marked as dropped in `droppedDims`.
2762 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2763  const llvm::SmallBitVector &droppedDims) {
2764  assert(size_t(t1.getRank()) == droppedDims.size() && "incorrect number of bits");
2765  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2766  "incorrect number of dropped dims");
2767  int64_t t1Offset, t2Offset;
2768  SmallVector<int64_t> t1Strides, t2Strides;
2769  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2770  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2771  if (failed(res1) || failed(res2))
2772  return false;
2773  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2774  if (droppedDims[i])
2775  continue;
2776  if (t1Strides[i] != t2Strides[j])
2777  return false;
2778  ++j;
2779  }
2780  return true;
2781 }
2782 
2783 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2784  Operation *op, Type expectedType) {
2785  auto memrefType = llvm::cast<ShapedType>(expectedType);
2786  switch (result) {
2787  case SliceVerificationResult::Success:
2788  return success();
2789  case SliceVerificationResult::RankTooLarge:
2790  return op->emitError("expected result rank to be smaller or equal to ")
2791  << "the source rank. ";
2792  case SliceVerificationResult::SizeMismatch:
2793  return op->emitError("expected result type to be ")
2794  << expectedType
2795  << " or a rank-reduced version. (mismatch of result sizes) ";
2796  case SliceVerificationResult::ElemTypeMismatch:
2797  return op->emitError("expected result element type to be ")
2798  << memrefType.getElementType();
2799  case SliceVerificationResult::MemSpaceMismatch:
2800  return op->emitError("expected result and source memory spaces to match.");
2801  case SliceVerificationResult::LayoutMismatch:
2802  return op->emitError("expected result type to be ")
2803  << expectedType
2804  << " or a rank-reduced version. (mismatch of result layout) ";
2805  }
2806  llvm_unreachable("unexpected subview verification result");
2807 }
2808 
2809 /// Verifier for SubViewOp.
2810 LogicalResult SubViewOp::verify() {
2811  MemRefType baseType = getSourceType();
2812  MemRefType subViewType = getType();
2813 
2814  // The base memref and the view memref should be in the same memory space.
2815  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2816  return emitError("different memory spaces specified for base memref "
2817  "type ")
2818  << baseType << " and subview memref type " << subViewType;
2819 
2820  // Verify that the base memref type has a strided layout map.
2821  if (!isStrided(baseType))
2822  return emitError("base type ") << baseType << " is not strided";
2823 
2824  // Compute the expected result type, assuming that there are no rank
2825  // reductions.
2826  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
2827  baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
2828 
2829  // Verify all properties of a shaped type: rank, element type and dimension
2830  // sizes. This takes into account potential rank reductions.
2831  auto shapedTypeVerification = isRankReducedType(
2832  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2833  if (shapedTypeVerification != SliceVerificationResult::Success)
2834  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2835 
2836  // Make sure that the memory space did not change.
2837  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
2838  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
2839  *this, expectedType);
2840 
2841  // Verify the offset of the layout map.
2842  if (!haveCompatibleOffsets(expectedType, subViewType))
2843  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2844  *this, expectedType);
2845 
2846  // The only thing that's left to verify now are the strides. First, compute
2847  // the unused dimensions due to rank reductions. We have to look at sizes and
2848  // strides to decide which dimensions were dropped. This function also
2849  // partially verifies strides in case of rank reductions.
2850  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
2851  getMixedSizes());
2852  if (failed(unusedDims))
2853  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2854  *this, expectedType);
2855 
2856  // Strides must match.
2857  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
2858  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2859  *this, expectedType);
2860 
2861  return success();
2862 }
2863 
2864 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2865  return os << "range " << range.offset << ":" << range.size << ":"
2866  << range.stride;
2867 }
2868 
2869 /// Return the list of Range (i.e. offset, size, stride). Each Range
2870 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2871 /// with `b` at location `loc`.
2872 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2873  OpBuilder &b, Location loc) {
2874  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
2875  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
2876  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
2877  SmallVector<Range, 8> res;
2878  unsigned rank = ranks[0];
2879  res.reserve(rank);
2880  for (unsigned idx = 0; idx < rank; ++idx) {
2881  Value offset =
2882  op.isDynamicOffset(idx)
2883  ? op.getDynamicOffset(idx)
2884  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
2885  Value size =
2886  op.isDynamicSize(idx)
2887  ? op.getDynamicSize(idx)
2888  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
2889  Value stride =
2890  op.isDynamicStride(idx)
2891  ? op.getDynamicStride(idx)
2892  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
2893  res.emplace_back(Range{offset, size, stride});
2894  }
2895  return res;
2896 }
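
/// As a sketch (hypothetical values): for a subview with mixed static and
/// dynamic operands such as
/// ```
/// %sv = memref.subview %src[%off, 4] [8, %sz] [1, 1]
///     : memref<?x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>
/// ```
/// the returned ranges would be {%off, c8, c1} and {c4, %sz, c1}, where c8, c4,
/// and c1 denote arith.constant index ops created with `b` at `loc`.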
2897 
2898 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
2899 /// to deduce the result type for the given `sourceType`. Additionally, reduce
2900 /// the rank of the inferred result type if `currentResultType` is lower rank
2901 /// than `currentSourceType`. Use this signature if `sourceType` is updated
2902 /// together with the result type. In this case, it is important to compute
2903 /// the dropped dimensions using `currentSourceType` whose strides align with
2904 /// `currentResultType`.
2905 static MemRefType getCanonicalSubViewResultType(
2906  MemRefType currentResultType, MemRefType currentSourceType,
2907  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
2908  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
2909  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2910  sourceType, mixedOffsets, mixedSizes, mixedStrides));
2911  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
2912  currentSourceType, currentResultType, mixedSizes);
2913  if (failed(unusedDims))
2914  return nullptr;
2915 
2916  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
2917  SmallVector<int64_t> shape, strides;
2918  unsigned numDimsAfterReduction =
2919  nonRankReducedType.getRank() - unusedDims->count();
2920  shape.reserve(numDimsAfterReduction);
2921  strides.reserve(numDimsAfterReduction);
2922  for (const auto &[idx, size, stride] :
2923  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
2924  nonRankReducedType.getShape(), layout.getStrides())) {
2925  if (unusedDims->test(idx))
2926  continue;
2927  shape.push_back(size);
2928  strides.push_back(stride);
2929  }
2930 
2931  return MemRefType::get(shape, nonRankReducedType.getElementType(),
2932  StridedLayoutAttr::get(sourceType.getContext(),
2933  layout.getOffset(), strides),
2934  nonRankReducedType.getMemorySpace());
2935 }
2936 
2937 Value mlir::memref::createCanonicalRankReducingSubViewOp(
2938  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
2939  auto memrefType = llvm::cast<MemRefType>(memref.getType());
2940  unsigned rank = memrefType.getRank();
2941  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2942  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
2943  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2944  auto targetType =
2945  llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
2946  targetShape, memrefType, offsets, sizes, strides));
2947  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
2948  sizes, strides);
2949 }
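
/// As a sketch of the IR this helper produces (hypothetical values): for a
/// source `%m` of type memref<1x16xf32> and targetShape {16}, it builds,
/// modulo folding:
/// ```
/// %sv = memref.subview %m[0, 0] [1, 16] [1, 1]
///     : memref<1x16xf32> to memref<16xf32, strided<[1]>>
/// ```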
2950 
2951 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
2952  Value value,
2953  ArrayRef<int64_t> desiredShape) {
2954  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
2955  assert(sourceMemrefType && "not a ranked memref type");
2956  auto sourceShape = sourceMemrefType.getShape();
2957  if (sourceShape.equals(desiredShape))
2958  return value;
2959  auto maybeRankReductionMask =
2960  mlir::computeRankReductionMask(sourceShape, desiredShape);
2961  if (!maybeRankReductionMask)
2962  return failure();
2963  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
2964 }
2965 
2966 /// Helper method to check if a `subview` operation is trivially a no-op. This
2967 /// is the case if all offsets are zero, all strides are 1, and the source
2968 /// shape is the same as the sizes of the subview. In such cases, the subview
2969 /// can be folded into its source.
2970 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
2971  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
2972  return false;
2973 
2974  auto mixedOffsets = subViewOp.getMixedOffsets();
2975  auto mixedSizes = subViewOp.getMixedSizes();
2976  auto mixedStrides = subViewOp.getMixedStrides();
2977 
2978  // Check offsets are zero.
2979  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
2980  std::optional<int64_t> intValue = getConstantIntValue(ofr);
2981  return !intValue || intValue.value() != 0;
2982  }))
2983  return false;
2984 
2985  // Check strides are one.
2986  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
2987  std::optional<int64_t> intValue = getConstantIntValue(ofr);
2988  return !intValue || intValue.value() != 1;
2989  }))
2990  return false;
2991 
2992  // Check that all size values are static and match the (static) source shape.
2993  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
2994  for (const auto &size : llvm::enumerate(mixedSizes)) {
2995  std::optional<int64_t> intValue = getConstantIntValue(size.value());
2996  if (!intValue || *intValue != sourceShape[size.index()])
2997  return false;
2998  }
2999  // All conditions met. The `SubViewOp` is foldable as a no-op.
3000  return true;
3001 }
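
/// For example (hypothetical values), the following subview is trivially a
/// no-op: zero offsets, unit strides, and sizes equal to the static source
/// shape:
/// ```
/// %sv = memref.subview %src[0, 0] [4, 8] [1, 1]
///     : memref<4x8xf32> to memref<4x8xf32, strided<[8, 1]>>
/// ```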
3002 
3003 namespace {
3004 /// Pattern to rewrite a subview op with MemRefCast arguments.
3005 /// This essentially pushes memref.cast past its consuming subview when
3006 /// `canFoldIntoConsumerOp` is true.
3007 ///
3008 /// Example:
3009 /// ```
3010 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3011 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3012 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3013 /// ```
3014 /// is rewritten into:
3015 /// ```
3016 /// %0 = memref.subview %V[0, 0] [3, 4] [1, 1] : memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3017 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3018 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3019 /// ```
3020 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3021 public:
3022  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3023 
3024  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3025  PatternRewriter &rewriter) const override {
3026  // If any operand is a constant, just return and let
3027  // SubViewOpConstantFolder kick in first.
3028  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3029  return matchPattern(operand, matchConstantIndex());
3030  }))
3031  return failure();
3032 
3033  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3034  if (!castOp)
3035  return failure();
3036 
3037  if (!CastOp::canFoldIntoConsumerOp(castOp))
3038  return failure();
3039 
3040  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3041  // the MemRefCastOp source operand type to infer the result type and the
3042  // current SubViewOp source operand type to compute the dropped dimensions
3043  // if the operation is rank-reducing.
3044  auto resultType = getCanonicalSubViewResultType(
3045  subViewOp.getType(), subViewOp.getSourceType(),
3046  llvm::cast<MemRefType>(castOp.getSource().getType()),
3047  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3048  subViewOp.getMixedStrides());
3049  if (!resultType)
3050  return failure();
3051 
3052  Value newSubView = rewriter.create<SubViewOp>(
3053  subViewOp.getLoc(), resultType, castOp.getSource(),
3054  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3055  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3056  subViewOp.getStaticStrides());
3057  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3058  newSubView);
3059  return success();
3060  }
3061 };
3062 
3063 /// Canonicalize subview ops that are no-ops. When the source type is not the
3064 /// same as the result type (e.g. due to the use of `affine_map`), a cast is inserted.
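/// A sketch of the second case (hypothetical values): the subview is a no-op,
/// but its result type spells out the strided layout, so it is replaced by a
/// cast rather than directly by the source:
/// ```
/// %sv = memref.subview %src[0, 0] [4, 8] [1, 1]
///     : memref<4x8xf32> to memref<4x8xf32, strided<[8, 1]>>
/// ```
/// becomes
/// ```
/// %c = memref.cast %src : memref<4x8xf32> to memref<4x8xf32, strided<[8, 1]>>
/// ```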
3065 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3066 public:
3067  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3068 
3069  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3070  PatternRewriter &rewriter) const override {
3071  if (!isTrivialSubViewOp(subViewOp))
3072  return failure();
3073  if (subViewOp.getSourceType() == subViewOp.getType()) {
3074  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3075  return success();
3076  }
3077  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3078  subViewOp.getSource());
3079  return success();
3080  }
3081 };
3082 } // namespace
3083 
3084 /// Return the canonical type of the result of a subview.
3085 struct SubViewReturnTypeCanonicalizer {
3086  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3087  ArrayRef<OpFoldResult> mixedSizes,
3088  ArrayRef<OpFoldResult> mixedStrides) {
3089  // Infer a memref type without taking into account any rank reductions.
3090  auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
3091  mixedSizes, mixedStrides);
3092  if (!resTy)
3093  return {};
3094  MemRefType nonReducedType = cast<MemRefType>(resTy);
3095 
3096  // Directly return the non-rank reduced type if there are no dropped dims.
3097  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3098  if (droppedDims.none())
3099  return nonReducedType;
3100 
3101  // Take the strides and offset from the non-rank reduced type.
3102  auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
3103 
3104  // Drop dims from shape and strides.
3105  SmallVector<int64_t> targetShape;
3106  SmallVector<int64_t> targetStrides;
3107  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3108  if (droppedDims.test(i))
3109  continue;
3110  targetStrides.push_back(nonReducedStrides[i]);
3111  targetShape.push_back(nonReducedType.getDimSize(i));
3112  }
3113 
3114  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3115  StridedLayoutAttr::get(nonReducedType.getContext(),
3116  offset, targetStrides),
3117  nonReducedType.getMemorySpace());
3118  }
3119 };
3120 
3121 /// A canonicalizer wrapper to replace SubViewOps.
3122 struct SubViewCanonicalizer {
3123  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3124  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3125  }
3126 };
3127 
3128 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3129  MLIRContext *context) {
3130  results
3131  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3132  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3133  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3134 }
3135 
3136 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3137  auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
3138  auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());
3139 
3140  if (resultShapedType.hasStaticShape() &&
3141  resultShapedType == sourceShapedType) {
3142  return getViewSource();
3143  }
3144 
3145  // Fold subview(subview(x)), where both subviews have the same size and the
3146  // second subview's offsets are all zero. (I.e., the second subview is a
3147  // no-op.)
3148  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3149  auto srcSizes = srcSubview.getMixedSizes();
3150  auto sizes = getMixedSizes();
3151  auto offsets = getMixedOffsets();
3152  bool allOffsetsZero = llvm::all_of(
3153  offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3154  auto strides = getMixedStrides();
3155  bool allStridesOne = llvm::all_of(
3156  strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3157  bool allSizesSame = llvm::equal(sizes, srcSizes);
3158  if (allOffsetsZero && allStridesOne && allSizesSame &&
3159  resultShapedType == sourceShapedType)
3160  return getViewSource();
3161  }
3162 
3163  return {};
3164 }
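
/// A sketch of the subview-of-subview case (hypothetical values): the outer
/// subview has zero offsets, unit strides, the same sizes as the inner one,
/// and the same result type as its source, so it folds to %0:
/// ```
/// %0 = memref.subview %src[%i, %j] [4, 4] [1, 1]
///     : memref<?x?xf32> to memref<4x4xf32, strided<[?, 1], offset: ?>>
/// %1 = memref.subview %0[0, 0] [4, 4] [1, 1]
///     : memref<4x4xf32, strided<[?, 1], offset: ?>>
///       to memref<4x4xf32, strided<[?, 1], offset: ?>>
/// ```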
3165 
3166 //===----------------------------------------------------------------------===//
3167 // TransposeOp
3168 //===----------------------------------------------------------------------===//
3169 
3170 void TransposeOp::getAsmResultNames(
3171  function_ref<void(Value, StringRef)> setNameFn) {
3172  setNameFn(getResult(), "transpose");
3173 }
3174 
3175 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3176 static MemRefType inferTransposeResultType(MemRefType memRefType,
3177  AffineMap permutationMap) {
3178  auto originalSizes = memRefType.getShape();
3179  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
3180  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3181 
3182  // Compute permuted sizes and strides.
3183  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3184  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3185 
3186  return MemRefType::Builder(memRefType)
3187  .setShape(sizes)
3188  .setLayout(
3189  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3190 }
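
/// For example (hypothetical values), permuting with (d0, d1) -> (d1, d0):
/// ```
/// // memref<3x4xf32> has strides [4, 1]; the permuted type is
/// //   memref<4x3xf32, strided<[1, 4]>>
/// %t = memref.transpose %m (d0, d1) -> (d1, d0)
///     : memref<3x4xf32> to memref<4x3xf32, strided<[1, 4]>>
/// ```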
3191 
3192 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3193  AffineMapAttr permutation,
3194  ArrayRef<NamedAttribute> attrs) {
3195  auto permutationMap = permutation.getValue();
3196  assert(permutationMap);
3197 
3198  auto memRefType = llvm::cast<MemRefType>(in.getType());
3199  // Compute result type.
3200  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3201 
3202  build(b, result, resultType, in, attrs);
3203  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3204 }
3205 
3206 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3207 void TransposeOp::print(OpAsmPrinter &p) {
3208  p << " " << getIn() << " " << getPermutation();
3209  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3210  p << " : " << getIn().getType() << " to " << getType();
3211 }
3212 
3213 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3214  OpAsmParser::UnresolvedOperand in;
3215  AffineMap permutation;
3216  MemRefType srcType, dstType;
3217  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3218  parser.parseOptionalAttrDict(result.attributes) ||
3219  parser.parseColonType(srcType) ||
3220  parser.resolveOperand(in, srcType, result.operands) ||
3221  parser.parseKeywordType("to", dstType) ||
3222  parser.addTypeToList(dstType, result.types))
3223  return failure();
3224 
3225  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3226  AffineMapAttr::get(permutation));
3227  return success();
3228 }
3229 
3230 LogicalResult TransposeOp::verify() {
3231  if (!getPermutation().isPermutation())
3232  return emitOpError("expected a permutation map");
3233  if (getPermutation().getNumDims() != getIn().getType().getRank())
3234  return emitOpError("expected a permutation map of same rank as the input");
3235 
3236  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3237  auto resultType = llvm::cast<MemRefType>(getType());
3238  auto canonicalResultType = canonicalizeStridedLayout(
3239  inferTransposeResultType(srcType, getPermutation()));
3240 
3241  if (canonicalizeStridedLayout(resultType) != canonicalResultType)
3242  return emitOpError("result type ")
3243  << resultType
3244  << " is not equivalent to the canonical transposed input type "
3245  << canonicalResultType;
3246  return success();
3247 }
3248 
3249 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3250  // First check for an identity permutation; it can be folded away if the
3251  // input and result types are already identical.
3252  if (getPermutation().isIdentity() && getType() == getIn().getType())
3253  return getIn();
3254  // Fold two consecutive memref.transpose Ops into one by composing their
3255  // permutation maps.
3256  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3257  AffineMap composedPermutation =
3258  getPermutation().compose(otherTransposeOp.getPermutation());
3259  getInMutable().assign(otherTransposeOp.getIn());
3260  setPermutation(composedPermutation);
3261  return getResult();
3262  }
3263  return {};
3264 }
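
/// A sketch of the composition case (hypothetical values):
/// ```
/// %0 = memref.transpose %arg (d0, d1) -> (d1, d0)
///     : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
/// %1 = memref.transpose %0 (d0, d1) -> (d1, d0)
///     : memref<8x4xf32, strided<[1, 8]>> to memref<4x8xf32, strided<[8, 1]>>
/// ```
/// The two permutations compose to the identity, so after folding %1 becomes a
/// single transpose of %arg with the identity permutation.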
3265 
3266 //===----------------------------------------------------------------------===//
3267 // ViewOp
3268 //===----------------------------------------------------------------------===//
3269 
3270 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3271  setNameFn(getResult(), "view");
3272 }
3273 
3274 LogicalResult ViewOp::verify() {
3275  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3276  auto viewType = getType();
3277 
3278  // The base memref should have an identity layout map (or none).
3279  if (!baseType.getLayout().isIdentity())
3280  return emitError("unsupported map for base memref type ") << baseType;
3281 
3282  // The result memref should have an identity layout map (or none).
3283  if (!viewType.getLayout().isIdentity())
3284  return emitError("unsupported map for result memref type ") << viewType;
3285 
3286  // The base memref and the view memref should be in the same memory space.
3287  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3288  return emitError("different memory spaces specified for base memref "
3289  "type ")
3290  << baseType << " and view memref type " << viewType;
3291 
3292  // Verify that we have the correct number of sizes for the result type.
3293  unsigned numDynamicDims = viewType.getNumDynamicDims();
3294  if (getSizes().size() != numDynamicDims)
3295  return emitError("incorrect number of size operands for type ") << viewType;
3296 
3297  return success();
3298 }
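
/// For illustration (hypothetical values), a view that satisfies these checks:
/// identity layouts, matching memory spaces, and one size operand per dynamic
/// dimension of the result type:
/// ```
/// %buf = memref.alloc() : memref<2048xi8>
/// %c0 = arith.constant 0 : index
/// %sz = arith.constant 32 : index
/// %v = memref.view %buf[%c0][%sz] : memref<2048xi8> to memref<?x16xf32>
/// ```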
3299 
3300 Value ViewOp::getViewSource() { return getSource(); }
3301 
3302 namespace {
3303 
3304 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3305  using OpRewritePattern<ViewOp>::OpRewritePattern;
3306 
3307  LogicalResult matchAndRewrite(ViewOp viewOp,
3308  PatternRewriter &rewriter) const override {
3309  // Return if none of the operands are constants.
3310  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3311  return matchPattern(operand, matchConstantIndex());
3312  }))
3313  return failure();
3314 
3315  // Get result memref type.
3316  auto memrefType = viewOp.getType();
3317 
3318  // Get the offset from the old memref view type 'memrefType'.
3319  int64_t oldOffset;
3320  SmallVector<int64_t, 4> oldStrides;
3321  if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
3322  return failure();
3323  assert(oldOffset == 0 && "Expected 0 offset");
3324 
3325  SmallVector<Value, 4> newOperands;
3326 
3327  // Offset cannot be folded into result type.
3328 
3329  // Fold any dynamic dim operands which are produced by a constant.
3330  SmallVector<int64_t, 4> newShapeConstants;
3331  newShapeConstants.reserve(memrefType.getRank());
3332 
3333  unsigned dynamicDimPos = 0;
3334  unsigned rank = memrefType.getRank();
3335  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3336  int64_t dimSize = memrefType.getDimSize(dim);
3337  // If this is already a static dimension, keep it.
3338  if (!ShapedType::isDynamic(dimSize)) {
3339  newShapeConstants.push_back(dimSize);
3340  continue;
3341  }
3342  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3343  if (auto constantIndexOp =
3344  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3345  // Dynamic shape dimension will be folded.
3346  newShapeConstants.push_back(constantIndexOp.value());
3347  } else {
3348  // Dynamic shape dimension not folded; copy operand from old memref.
3349  newShapeConstants.push_back(dimSize);
3350  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3351  }
3352  dynamicDimPos++;
3353  }
3354 
3355  // Create new memref type with constant folded dims.
3356  MemRefType newMemRefType =
3357  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3358  // Nothing new, don't fold.
3359  if (newMemRefType == memrefType)
3360  return failure();
3361 
3362  // Create new ViewOp.
3363  auto newViewOp = rewriter.create<ViewOp>(
3364  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3365  viewOp.getByteShift(), newOperands);
3366  // Insert a cast so we have the same type as the old memref type.
3367  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3368  return success();
3369  }
3370 };
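
/// A sketch of this rewrite (hypothetical values): a constant size operand is
/// folded into the result type, and a cast restores the original type:
/// ```
/// %c16 = arith.constant 16 : index
/// %v = memref.view %buf[%off][%c16] : memref<2048xi8> to memref<?x4xf32>
/// ```
/// becomes
/// ```
/// %v2 = memref.view %buf[%off][] : memref<2048xi8> to memref<16x4xf32>
/// %r = memref.cast %v2 : memref<16x4xf32> to memref<?x4xf32>
/// ```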
3371 
3372 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3373  using OpRewritePattern<ViewOp>::OpRewritePattern;
3374 
3375  LogicalResult matchAndRewrite(ViewOp viewOp,
3376  PatternRewriter &rewriter) const override {
3377  Value memrefOperand = viewOp.getOperand(0);
3378  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3379  if (!memrefCastOp)
3380  return failure();
3381  Value allocOperand = memrefCastOp.getOperand();
3382  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3383  if (!allocOp)
3384  return failure();
3385  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3386  viewOp.getByteShift(),
3387  viewOp.getSizes());
3388  return success();
3389  }
3390 };
3391 
3392 } // namespace
3393 
3394 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3395  MLIRContext *context) {
3396  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3397 }
3398 
3399 //===----------------------------------------------------------------------===//
3400 // AtomicRMWOp
3401 //===----------------------------------------------------------------------===//
3402 
3403 LogicalResult AtomicRMWOp::verify() {
3404  if (getMemRefType().getRank() != getNumOperands() - 2)
3405  return emitOpError(
3406  "expects the number of subscripts to be equal to memref rank");
3407  switch (getKind()) {
3408  case arith::AtomicRMWKind::addf:
3409  case arith::AtomicRMWKind::maximumf:
3410  case arith::AtomicRMWKind::minimumf:
3411  case arith::AtomicRMWKind::mulf:
3412  if (!llvm::isa<FloatType>(getValue().getType()))
3413  return emitOpError() << "with kind '"
3414  << arith::stringifyAtomicRMWKind(getKind())
3415  << "' expects a floating-point type";
3416  break;
3417  case arith::AtomicRMWKind::addi:
3418  case arith::AtomicRMWKind::maxs:
3419  case arith::AtomicRMWKind::maxu:
3420  case arith::AtomicRMWKind::mins:
3421  case arith::AtomicRMWKind::minu:
3422  case arith::AtomicRMWKind::muli:
3423  case arith::AtomicRMWKind::ori:
3424  case arith::AtomicRMWKind::andi:
3425  if (!llvm::isa<IntegerType>(getValue().getType()))
3426  return emitOpError() << "with kind '"
3427  << arith::stringifyAtomicRMWKind(getKind())
3428  << "' expects an integer type";
3429  break;
3430  default:
3431  break;
3432  }
3433  return success();
3434 }
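
/// For illustration (hypothetical values), a well-formed op: one index per
/// memref dimension, and a floating-point value type for the 'addf' kind:
/// ```
/// %old = memref.atomic_rmw addf %val, %mem[%i, %j]
///     : (f32, memref<8x8xf32>) -> f32
/// ```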
3435 
3436 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3437  /// atomicrmw(memrefcast) -> atomicrmw
3438  if (succeeded(foldMemRefCast(*this, getValue())))
3439  return getResult();
3440  return OpFoldResult();
3441 }
3442 
3443 //===----------------------------------------------------------------------===//
3444 // TableGen'd op method definitions
3445 //===----------------------------------------------------------------------===//
3446 
3447 #define GET_OP_CLASSES
3448 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"