MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Arith/Utils/Utils.h"
12 #include "mlir/Dialect/Utils/StaticValueUtils.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/ViewLikeInterface.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 namespace {
30 /// Idiomatic saturated operations on offsets, sizes and strides.
31 namespace saturated_arith {
32 struct Wrapper {
33  static Wrapper stride(int64_t v) {
34  return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
35  }
36  static Wrapper offset(int64_t v) {
37  return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
38  }
39  static Wrapper size(int64_t v) {
40  return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
41  }
42  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
43  int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
44  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
45  bool operator==(Wrapper other) {
46  return (saturated && other.saturated) ||
47  (!saturated && !other.saturated && v == other.v);
48  }
49  bool operator!=(Wrapper other) { return !(*this == other); }
50  Wrapper operator+(Wrapper other) {
51  if (saturated || other.saturated)
52  return Wrapper{true, 0};
53  return Wrapper{false, other.v + v};
54  }
55  Wrapper operator*(Wrapper other) {
56  if (saturated || other.saturated)
57  return Wrapper{true, 0};
58  return Wrapper{false, other.v * v};
59  }
60  bool saturated;
61  int64_t v;
62 };
63 } // namespace saturated_arith
64 } // namespace
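// Illustrative sketch of how the wrapper composes (using only the members
// defined above): anything combined with a dynamic value saturates, while
// fully static values behave like plain integer arithmetic.
//
//   using namespace saturated_arith;
//   (Wrapper::offset(2) + Wrapper::offset(3)).asOffset();            // 5
//   (Wrapper::stride(4) * Wrapper::stride(ShapedType::kDynamic))
//       .asStride();                               // ShapedType::kDynamic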
65 
66 /// Materialize a single constant operation from a given attribute value with
67 /// the desired resultant type.
68 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
69                                               Attribute value, Type type,
70  Location loc) {
71  return arith::ConstantOp::materialize(builder, value, type, loc);
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Common canonicalization pattern support logic
76 //===----------------------------------------------------------------------===//
77 
78 /// This is a common class used for patterns of the form
79 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
80 /// into the root operation directly.
81 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
82  bool folded = false;
83  for (OpOperand &operand : op->getOpOperands()) {
84  auto cast = operand.get().getDefiningOp<CastOp>();
85  if (cast && operand.get() != inner &&
86  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
87  operand.set(cast.getOperand());
88  folded = true;
89  }
90  }
91  return success(folded);
92 }
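// For example, with this helper a consumer of a cast such as
//
//   %0 = memref.cast %arg : memref<8x16xf32> to memref<?x?xf32>
//   memref.dealloc %0 : memref<?x?xf32>
//
// is rewritten to operate on %arg directly (see the fold hooks of DeallocOp,
// DimOp and the DMA ops below).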
93 
94 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
95 /// type.
96 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
97  if (auto memref = llvm::dyn_cast<MemRefType>(type))
98  return RankedTensorType::get(memref.getShape(), memref.getElementType());
99  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
100  return UnrankedTensorType::get(memref.getElementType());
101  return NoneType::get(type.getContext());
102 }
103 
104 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
105                                   int64_t dim) {
106  auto memrefType = llvm::cast<MemRefType>(value.getType());
108  if (memrefType.isDynamicDim(dim))
109  return builder.createOrFold<memref::DimOp>(loc, value, dim);
110 
111  return builder.getIndexAttr(memrefType.getDimSize(dim));
112 }
113 
114 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
115                                                 Location loc, Value value) {
116  auto memrefType = llvm::cast<MemRefType>(value.getType());
117  SmallVector<OpFoldResult> result;
118  for (int64_t i = 0; i < memrefType.getRank(); ++i)
119  result.push_back(getMixedSize(builder, loc, value, i));
120  return result;
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // Utility functions for propagating static information
125 //===----------------------------------------------------------------------===//
126 
127 /// Helper function that infers the constant values from a list of \p values,
128 /// a \p memRefTy, and another helper function \p getAttributes.
129 /// The inferred constant values replace the related `OpFoldResult` in
130 /// \p values.
131 ///
132 /// \note This function shouldn't be used directly, instead, use the
133 /// `getConstifiedMixedXXX` methods from the related operations.
134 ///
135 /// \p getAttributes returns a list of potentially constant values, as determined
136 /// by \p isDynamic, from the given \p memRefTy. The returned list must have as
137 /// many elements as \p values or be empty.
138 ///
139 /// E.g., consider the following example:
140 /// ```
141 /// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
142 /// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
143 /// ```
144 /// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
145 /// Now using this helper function with:
146 /// - `values == [2, %dyn_stride]`,
147 /// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
148 /// - `getAttributes == getConstantStrides` (i.e., a wrapper around
149 /// `getStridesAndOffset`), and
150 /// - `isDynamic == ShapedType::isDynamic`
151 /// Will yield: `values == [2, 1]`
152 static void constifyIndexValues(
153     SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
154  MLIRContext *ctxt,
155  llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
156  llvm::function_ref<bool(int64_t)> isDynamic) {
157  SmallVector<int64_t> constValues = getAttributes(memRefTy);
158  Builder builder(ctxt);
159  for (const auto &it : llvm::enumerate(constValues)) {
160  int64_t constValue = it.value();
161  if (!isDynamic(constValue))
162  values[it.index()] = builder.getIndexAttr(constValue);
163  }
164  for (OpFoldResult &ofr : values) {
165  if (ofr.is<Attribute>()) {
166  // FIXME: We shouldn't need to do that, but right now, the static indices
167  // are created with the wrong type: `i64` instead of `index`.
168  // As a result, if we were to keep the attribute as is, we may fail to see
169  // that two attributes are equal because one would have the i64 type and
170  // the other the index type.
171  // The alternative would be to create constant indices with getI64Attr in
172  // this and the previous loop, but it doesn't logically make sense (we are
173  // dealing with indices here) and would only strengthen the inconsistency
174  // around how static indices are created (some places use getI64Attr,
175  // others use getIndexAttr).
176  // The workaround here is to stick to the IndexAttr type for all the
177  // values, hence we recreate the attribute even when it is already static
178  // to make sure the type is consistent.
179  ofr = builder.getIndexAttr(
180  llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
181  continue;
182  }
183  std::optional<int64_t> maybeConstant =
184  getConstantIntValue(ofr.get<Value>());
185  if (maybeConstant)
186  ofr = builder.getIndexAttr(*maybeConstant);
187  }
188 }
189 
190 /// Wrapper around `getShape` that conforms to the function signature
191 /// expected for `getAttributes` in `constifyIndexValues`.
192 static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
193  ArrayRef<int64_t> sizes = memRefTy.getShape();
194  return SmallVector<int64_t>(sizes.begin(), sizes.end());
195 }
196 
197 /// Wrapper around `getStridesAndOffset` that returns only the offset and
198 /// conforms to the function signature expected for `getAttributes` in
199 /// `constifyIndexValues`.
200 static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
201  SmallVector<int64_t> strides;
202  int64_t offset;
203  LogicalResult hasStaticInformation =
204  getStridesAndOffset(memrefType, strides, offset);
205  if (failed(hasStaticInformation))
206  return SmallVector<int64_t>();
207  return SmallVector<int64_t>(1, offset);
208 }
209 
210 /// Wrapper around `getStridesAndOffset` that returns only the strides and
211 /// conforms to the function signature expected for `getAttributes` in
212 /// `constifyIndexValues`.
213 static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
214  SmallVector<int64_t> strides;
215  int64_t offset;
216  LogicalResult hasStaticInformation =
217  getStridesAndOffset(memrefType, strides, offset);
218  if (failed(hasStaticInformation))
219  return SmallVector<int64_t>();
220  return strides;
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // AllocOp / AllocaOp
225 //===----------------------------------------------------------------------===//
226 
227 void AllocOp::getAsmResultNames(
228  function_ref<void(Value, StringRef)> setNameFn) {
229  setNameFn(getResult(), "alloc");
230 }
231 
232 void AllocaOp::getAsmResultNames(
233  function_ref<void(Value, StringRef)> setNameFn) {
234  setNameFn(getResult(), "alloca");
235 }
236 
237 template <typename AllocLikeOp>
238 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
239  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
240  "applies to only alloc or alloca");
241  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
242  if (!memRefType)
243  return op.emitOpError("result must be a memref");
244 
245  if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
246  memRefType.getNumDynamicDims())
247  return op.emitOpError("dimension operand count does not equal memref "
248  "dynamic dimension count");
249 
250  unsigned numSymbols = 0;
251  if (!memRefType.getLayout().isIdentity())
252  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
253  if (op.getSymbolOperands().size() != numSymbols)
254  return op.emitOpError("symbol operand count does not equal memref symbol "
255  "count: expected ")
256  << numSymbols << ", got " << op.getSymbolOperands().size();
257 
258  return success();
259 }
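// As an illustration, an alloc accepted by this verifier carries one size
// operand per dynamic dimension and one operand per symbol of a non-identity
// layout, e.g.:
//
//   %0 = memref.alloc(%d)[%s]
//       : memref<?x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 16 + d1 + s0)>>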
260 
261 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
262 
263 LogicalResult AllocaOp::verify() {
264  // An alloca op needs to have an ancestor with an allocation scope trait.
265  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
266  return emitOpError(
267  "requires an ancestor op with AutomaticAllocationScope trait");
268 
269  return verifyAllocLikeOp(*this);
270 }
271 
272 namespace {
273 /// Fold constant dimensions into an alloc like operation.
274 template <typename AllocLikeOp>
275 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
276  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
277 
278  LogicalResult matchAndRewrite(AllocLikeOp alloc,
279  PatternRewriter &rewriter) const override {
280  // Check to see if any dimensions operands are constants. If so, we can
281  // substitute and drop them.
282  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
283  APInt constSizeArg;
284  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
285  return false;
286  return constSizeArg.isNonNegative();
287  }))
288  return failure();
289 
290  auto memrefType = alloc.getType();
291 
292  // Ok, we have one or more constant operands. Collect the non-constant ones
293  // and keep track of the resultant memref type to build.
294  SmallVector<int64_t, 4> newShapeConstants;
295  newShapeConstants.reserve(memrefType.getRank());
296  SmallVector<Value, 4> dynamicSizes;
297 
298  unsigned dynamicDimPos = 0;
299  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
300  int64_t dimSize = memrefType.getDimSize(dim);
301  // If this is already static dimension, keep it.
302  if (!ShapedType::isDynamic(dimSize)) {
303  newShapeConstants.push_back(dimSize);
304  continue;
305  }
306  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
307  APInt constSizeArg;
308  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
309  constSizeArg.isNonNegative()) {
310  // Dynamic shape dimension will be folded.
311  newShapeConstants.push_back(constSizeArg.getZExtValue());
312  } else {
313  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
314  newShapeConstants.push_back(ShapedType::kDynamic);
315  dynamicSizes.push_back(dynamicSize);
316  }
317  dynamicDimPos++;
318  }
319 
320  // Create new memref type (which will have fewer dynamic dimensions).
321  MemRefType newMemRefType =
322  MemRefType::Builder(memrefType).setShape(newShapeConstants);
323  assert(static_cast<int64_t>(dynamicSizes.size()) ==
324  newMemRefType.getNumDynamicDims());
325 
326  // Create and insert the alloc op for the new memref.
327  auto newAlloc = rewriter.create<AllocLikeOp>(
328  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
329  alloc.getAlignmentAttr());
330  // Insert a cast so we have the same type as the old alloc.
331  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
332  return success();
333  }
334 };
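// For example, assuming %c16 is an arith.constant 16 : index, this pattern
// rewrites
//
//   %0 = memref.alloc(%c16) : memref<?xf32>
//
// into
//
//   %1 = memref.alloc() : memref<16xf32>
//   %0 = memref.cast %1 : memref<16xf32> to memref<?xf32>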
335 
336 /// Fold alloc operations with no users or only store and dealloc uses.
337 template <typename T>
338 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
339  using OpRewritePattern<T>::OpRewritePattern;
340 
341  LogicalResult matchAndRewrite(T alloc,
342  PatternRewriter &rewriter) const override {
343  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
344  if (auto storeOp = dyn_cast<StoreOp>(op))
345  return storeOp.getValue() == alloc;
346  return !isa<DeallocOp>(op);
347  }))
348  return failure();
349 
350  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
351  rewriter.eraseOp(user);
352 
353  rewriter.eraseOp(alloc);
354  return success();
355  }
356 };
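// For example (with placeholder values %v and %i), the following allocation
// and all of its uses are erased, since it is only written to and
// deallocated, never read:
//
//   %0 = memref.alloc() : memref<4xf32>
//   memref.store %v, %0[%i] : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>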
357 } // namespace
358 
359 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
360  MLIRContext *context) {
361  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
362 }
363 
364 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
365  MLIRContext *context) {
366  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
367  context);
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // ReallocOp
372 //===----------------------------------------------------------------------===//
373 
374 LogicalResult ReallocOp::verify() {
375  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
376  MemRefType resultType = getType();
377 
378  // The source memref should have identity layout (or none).
379  if (!sourceType.getLayout().isIdentity())
380  return emitError("unsupported layout for source memref type ")
381  << sourceType;
382 
383  // The result memref should have identity layout (or none).
384  if (!resultType.getLayout().isIdentity())
385  return emitError("unsupported layout for result memref type ")
386  << resultType;
387 
388  // The source memref and the result memref should be in the same memory space.
389  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
390  return emitError("different memory spaces specified for source memref "
391  "type ")
392  << sourceType << " and result memref type " << resultType;
393 
394  // The source memref and the result memref should have the same element type.
395  if (sourceType.getElementType() != resultType.getElementType())
396  return emitError("different element types specified for source memref "
397  "type ")
398  << sourceType << " and result memref type " << resultType;
399 
400  // Verify that we have the dynamic dimension operand when it is needed.
401  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
402  return emitError("missing dimension operand for result type ")
403  << resultType;
404  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
405  return emitError("unnecessary dimension operand for result type ")
406  << resultType;
407 
408  return success();
409 }
410 
411 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
412  MLIRContext *context) {
413  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // AllocaScopeOp
418 //===----------------------------------------------------------------------===//
419 
420 void AllocaScopeOp::print(OpAsmPrinter &p) {
421  bool printBlockTerminators = false;
422 
423  p << ' ';
424  if (!getResults().empty()) {
425  p << " -> (" << getResultTypes() << ")";
426  printBlockTerminators = true;
427  }
428  p << ' ';
429  p.printRegion(getBodyRegion(),
430  /*printEntryBlockArgs=*/false,
431  /*printBlockTerminators=*/printBlockTerminators);
432  p.printOptionalAttrDict((*this)->getAttrs());
433 }
434 
435 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
436  // Create a region for the body.
437  result.regions.reserve(1);
438  Region *bodyRegion = result.addRegion();
439 
440  // Parse optional results type list.
441  if (parser.parseOptionalArrowTypeList(result.types))
442  return failure();
443 
444  // Parse the body region.
445  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
446  return failure();
447  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
448  result.location);
449 
450  // Parse the optional attribute list.
451  if (parser.parseOptionalAttrDict(result.attributes))
452  return failure();
453 
454  return success();
455 }
456 
457 void AllocaScopeOp::getSuccessorRegions(
458     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
459  if (!point.isParent()) {
460  regions.push_back(RegionSuccessor(getResults()));
461  return;
462  }
463 
464  regions.push_back(RegionSuccessor(&getBodyRegion()));
465 }
466 
467 /// Given an operation, return whether this op is guaranteed to
468 /// allocate an AutomaticAllocationScopeResource
469 static bool isGuaranteedAutomaticAllocation(Operation *op) {
470  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
471  if (!interface)
472  return false;
473  for (auto res : op->getResults()) {
474  if (auto effect =
475  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
476  if (isa<SideEffects::AutomaticAllocationScopeResource>(
477  effect->getResource()))
478  return true;
479  }
480  }
481  return false;
482 }
483 
484 /// Given an operation, return whether this op itself could
485 /// allocate an AutomaticAllocationScopeResource. Note that
486 /// this will not check whether an operation contained within
487 /// the op can allocate.
488 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
489  // This op itself doesn't create a stack allocation,
490  // the inner allocation should be handled separately.
491  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
492  return false;
493  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
494  if (!interface)
495  return true;
496  for (auto res : op->getResults()) {
497  if (auto effect =
498  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
499  if (isa<SideEffects::AutomaticAllocationScopeResource>(
500  effect->getResource()))
501  return true;
502  }
503  }
504  return false;
505 }
506 
507 /// Return whether this op is the last non terminating op
508 /// in a region. That is to say, it is in a one-block region
509 /// and is only followed by a terminator. This prevents
510 /// extending the lifetime of allocations.
511 static bool lastNonTerminatorInRegion(Operation *op) {
512  return op->getNextNode() == op->getBlock()->getTerminator() &&
513  op->getParentRegion()->getBlocks().size() == 1;
514 }
515 
516 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
517 /// or it contains no allocation.
518 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
519  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
520 
521  LogicalResult matchAndRewrite(AllocaScopeOp op,
522  PatternRewriter &rewriter) const override {
523  bool hasPotentialAlloca =
524  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
525  if (alloc == op)
526  return WalkResult::advance();
527  if (isOpItselfPotentialAutomaticAllocation(alloc))
528  return WalkResult::interrupt();
529  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
530  return WalkResult::skip();
531  return WalkResult::advance();
532  }).wasInterrupted();
533 
534  // If this contains no potential allocation, it is always legal to
535  // inline. Otherwise, consider two conditions:
536  if (hasPotentialAlloca) {
537  // If the parent isn't an allocation scope, or we are not the last
538  // non-terminator op in the parent, we will extend the lifetime.
539  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
540  return failure();
541  if (!lastNonTerminatorInRegion(op))
542  return failure();
543  }
544 
545  Block *block = &op.getRegion().front();
546  Operation *terminator = block->getTerminator();
547  ValueRange results = terminator->getOperands();
548  rewriter.inlineBlockBefore(block, op);
549  rewriter.replaceOp(op, results);
550  rewriter.eraseOp(terminator);
551  return success();
552  }
553 };
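// For example, a scope that provably contains no stack allocation is inlined
// into its parent block:
//
//   %r = memref.alloca_scope -> (index) {
//     %c = arith.constant 0 : index
//     memref.alloca_scope.return %c : index
//   }
//
// becomes a plain arith.constant, with %r replaced by %c.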
554 
555 /// Move allocations into an allocation scope, if it is legal to
556 /// move them (e.g. their operands are available at the location
557 /// the op would be moved to).
558 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
559  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
560 
561  LogicalResult matchAndRewrite(AllocaScopeOp op,
562  PatternRewriter &rewriter) const override {
563 
564  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
565  return failure();
566 
567  Operation *lastParentWithoutScope = op->getParentOp();
568 
569  if (!lastParentWithoutScope ||
570  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
571  return failure();
572 
573  // Only apply if this is the last non-terminator op in the
574  // block (lest the lifetime of allocations be extended) of a
575  // one-block region.
576  if (!lastNonTerminatorInRegion(op) ||
577  !lastNonTerminatorInRegion(lastParentWithoutScope))
578  return failure();
579 
580  while (!lastParentWithoutScope->getParentOp()
581              ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
582  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
583  if (!lastParentWithoutScope ||
584  !lastNonTerminatorInRegion(lastParentWithoutScope))
585  return failure();
586  }
587  assert(lastParentWithoutScope->getParentOp()
588              ->hasTrait<OpTrait::AutomaticAllocationScope>());
589 
590  Region *containingRegion = nullptr;
591  for (auto &r : lastParentWithoutScope->getRegions()) {
592  if (r.isAncestor(op->getParentRegion())) {
593  assert(containingRegion == nullptr &&
594  "only one region can contain the op");
595  containingRegion = &r;
596  }
597  }
598  assert(containingRegion && "op must be contained in a region");
599 
600  SmallVector<Operation *> toHoist;
601  op->walk([&](Operation *alloc) {
602  if (!isOpItselfPotentialAutomaticAllocation(alloc))
603  return WalkResult::skip();
604 
605  // If any operand is not defined before the location of
606  // lastParentWithoutScope (i.e. where we would hoist to), skip.
607  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
608  return containingRegion->isAncestor(v.getParentRegion());
609  }))
610  return WalkResult::skip();
611  toHoist.push_back(alloc);
612  return WalkResult::advance();
613  });
614 
615  if (toHoist.empty())
616  return failure();
617  rewriter.setInsertionPoint(lastParentWithoutScope);
618  for (auto *op : toHoist) {
619  auto *cloned = rewriter.clone(*op);
620  rewriter.replaceOp(op, cloned->getResults());
621  }
622  return success();
623  }
624 };
625 
626 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
627  MLIRContext *context) {
628  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
629 }
630 
631 //===----------------------------------------------------------------------===//
632 // AssumeAlignmentOp
633 //===----------------------------------------------------------------------===//
634 
635 LogicalResult AssumeAlignmentOp::verify() {
636  if (!llvm::isPowerOf2_32(getAlignment()))
637  return emitOpError("alignment must be power of 2");
638  return success();
639 }
640 
641 //===----------------------------------------------------------------------===//
642 // CastOp
643 //===----------------------------------------------------------------------===//
644 
645 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
646  setNameFn(getResult(), "cast");
647 }
648 
649 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
650 /// source memref. This is useful to fold a memref.cast into a consuming op
651 /// and implement canonicalization patterns for ops in different dialects that
652 /// may consume the results of memref.cast operations. Such foldable memref.cast
653 /// operations are typically inserted as `view` and `subview` ops are
654 /// canonicalized, to preserve the type compatibility of their uses.
655 ///
656 /// Returns true when all conditions are met:
657 /// 1. source and result are ranked memrefs with strided semantics and same
658 /// element type and rank.
659 /// 2. each of the source's size, offset or stride has more static information
660 /// than the corresponding result's size, offset or stride.
661 ///
662 /// Example 1:
663 /// ```mlir
664 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
665 /// %2 = consumer %1 ... : memref<?x?xf32> ...
666 /// ```
667 ///
668 /// may fold into:
669 ///
670 /// ```mlir
671 /// %2 = consumer %0 ... : memref<8x16xf32> ...
672 /// ```
673 ///
674 /// Example 2:
675 /// ```
676 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
677 /// to memref<?x?xf32>
678 /// consumer %1 : memref<?x?xf32> ...
679 /// ```
680 ///
681 /// may fold into:
682 ///
683 /// ```
684 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
685 /// ```
686 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
687  MemRefType sourceType =
688  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
689  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
690 
691  // Requires ranked MemRefType.
692  if (!sourceType || !resultType)
693  return false;
694 
695  // Requires same elemental type.
696  if (sourceType.getElementType() != resultType.getElementType())
697  return false;
698 
699  // Requires same rank.
700  if (sourceType.getRank() != resultType.getRank())
701  return false;
702 
703  // Only fold casts between strided memref forms.
704  int64_t sourceOffset, resultOffset;
705  SmallVector<int64_t, 4> sourceStrides, resultStrides;
706  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
707  failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
708  return false;
709 
710  // If cast is towards more static sizes along any dimension, don't fold.
711  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
712  auto ss = std::get<0>(it), st = std::get<1>(it);
713  if (ss != st)
714  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
715  return false;
716  }
717 
718  // If cast is towards more static offset along any dimension, don't fold.
719  if (sourceOffset != resultOffset)
720  if (ShapedType::isDynamic(sourceOffset) &&
721  !ShapedType::isDynamic(resultOffset))
722  return false;
723 
724  // If cast is towards more static strides along any dimension, don't fold.
725  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
726  auto ss = std::get<0>(it), st = std::get<1>(it);
727  if (ss != st)
728  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
729  return false;
730  }
731 
732  return true;
733 }
734 
735 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
736  if (inputs.size() != 1 || outputs.size() != 1)
737  return false;
738  Type a = inputs.front(), b = outputs.front();
739  auto aT = llvm::dyn_cast<MemRefType>(a);
740  auto bT = llvm::dyn_cast<MemRefType>(b);
741 
742  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
743  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
744 
745  if (aT && bT) {
746  if (aT.getElementType() != bT.getElementType())
747  return false;
748  if (aT.getLayout() != bT.getLayout()) {
749  int64_t aOffset, bOffset;
750  SmallVector<int64_t, 4> aStrides, bStrides;
751  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
752  failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
753  aStrides.size() != bStrides.size())
754  return false;
755 
756  // Strides along a dimension/offset are compatible if the value in the
757  // source memref is static and the value in the target memref is the
758  // same. They are also compatible if either one is dynamic (see
759  // description of MemRefCastOp for details).
760  auto checkCompatible = [](int64_t a, int64_t b) {
761  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
762  };
763  if (!checkCompatible(aOffset, bOffset))
764  return false;
765  for (const auto &aStride : enumerate(aStrides))
766  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
767  return false;
768  }
769  if (aT.getMemorySpace() != bT.getMemorySpace())
770  return false;
771 
772  // They must have the same rank, and any specified dimensions must match.
773  if (aT.getRank() != bT.getRank())
774  return false;
775 
776  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
777  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
778  if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
779  aDim != bDim)
780  return false;
781  }
782  return true;
783  } else {
784  if (!aT && !uaT)
785  return false;
786  if (!bT && !ubT)
787  return false;
788  // Unranked to unranked casting is unsupported
789  if (uaT && ubT)
790  return false;
791 
792  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
793  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
794  if (aEltType != bEltType)
795  return false;
796 
797  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
798  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
799  return aMemSpace == bMemSpace;
800  }
801 
802  return false;
803 }
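// A few illustrative instances of the compatibility rules above:
//
//   memref<8x16xf32> to memref<?x?xf32>   // compatible (relaxing sizes)
//   memref<?xf32>    to memref<*xf32>     // compatible (erasing the rank)
//   memref<*xf32>    to memref<4xf32>     // compatible (specializing the rank)
//   memref<*xf32>    to memref<*xf32>     // not compatible (unranked to unranked)
//   memref<4xf32>    to memref<4xi32>     // not compatible (element type differs)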
804 
805 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
806  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
807 }
808 
809 //===----------------------------------------------------------------------===//
810 // CopyOp
811 //===----------------------------------------------------------------------===//
812 
813 namespace {
814 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
815 /// and element type, the cast can be skipped. Such CastOps only cast the layout
816 /// of the type.
817 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
818  using OpRewritePattern<CopyOp>::OpRewritePattern;
819 
820  LogicalResult matchAndRewrite(CopyOp copyOp,
821  PatternRewriter &rewriter) const override {
822  bool modified = false;
823 
824  // Check source.
825  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
826  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
827  auto toType = llvm::dyn_cast<MemRefType>(copyOp.getSource().getType());
828 
829  if (fromType && toType) {
830  if (fromType.getShape() == toType.getShape() &&
831  fromType.getElementType() == toType.getElementType()) {
832  rewriter.updateRootInPlace(copyOp, [&] {
833  copyOp.getSourceMutable().assign(castOp.getSource());
834  });
835  modified = true;
836  }
837  }
838  }
839 
840  // Check target.
841  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
842  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
843  auto toType = llvm::dyn_cast<MemRefType>(copyOp.getTarget().getType());
844 
845  if (fromType && toType) {
846  if (fromType.getShape() == toType.getShape() &&
847  fromType.getElementType() == toType.getElementType()) {
848  rewriter.updateRootInPlace(copyOp, [&] {
849  copyOp.getTargetMutable().assign(castOp.getSource());
850  });
851  modified = true;
852  }
853  }
854  }
855 
856  return success(modified);
857  }
858 };
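// For example, a cast that only changes the layout, not the shape or element
// type, is skipped on the way into the copy:
//
//   %0 = memref.cast %arg : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32>
//   memref.copy %0, %dst : memref<4xf32> to memref<4xf32>
//
// is rewritten so that the copy reads from %arg directly.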
859 
860 /// Fold memref.copy(%x, %x).
861 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
862  using OpRewritePattern<CopyOp>::OpRewritePattern;
863 
864  LogicalResult matchAndRewrite(CopyOp copyOp,
865  PatternRewriter &rewriter) const override {
866  if (copyOp.getSource() != copyOp.getTarget())
867  return failure();
868 
869  rewriter.eraseOp(copyOp);
870  return success();
871  }
872 };
873 } // namespace
874 
875 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
876  MLIRContext *context) {
877  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
878 }
879 
880 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
881                            SmallVectorImpl<OpFoldResult> &results) {
882  /// copy(memrefcast) -> copy
883  bool folded = false;
884  Operation *op = *this;
885  for (OpOperand &operand : op->getOpOperands()) {
886  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
887  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
888  operand.set(castOp.getOperand());
889  folded = true;
890  }
891  }
892  return success(folded);
893 }
894 
895 //===----------------------------------------------------------------------===//
896 // DeallocOp
897 //===----------------------------------------------------------------------===//
898 
899 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
900                               SmallVectorImpl<OpFoldResult> &results) {
901  /// dealloc(memrefcast) -> dealloc
902  return foldMemRefCast(*this);
903 }
904 
905 //===----------------------------------------------------------------------===//
906 // DimOp
907 //===----------------------------------------------------------------------===//
908 
909 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
910  setNameFn(getResult(), "dim");
911 }
912 
913 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
914  int64_t index) {
915  auto loc = result.location;
916  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
917  build(builder, result, source, indexValue);
918 }
919 
920 std::optional<int64_t> DimOp::getConstantIndex() {
921  return getConstantIntValue(getIndex());
922 }
923 
924 Speculation::Speculatability DimOp::getSpeculatability() {
925  auto constantIndex = getConstantIndex();
926  if (!constantIndex)
927  return Speculation::NotSpeculatable;
928 
929  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
930  if (!rankedSourceType)
931  return Speculation::NotSpeculatable;
932 
933  // The verifier rejects operations that violate this assertion.
934  assert(constantIndex < rankedSourceType.getRank());
935  return Speculation::Speculatable;
936 }
937 
938 /// Return a map with key being elements in `vals` and data being number of
939 /// occurrences of it. Use std::map, since the `vals` here are strides and the
940 /// dynamic stride value is the same as the tombstone value for
941 /// `DenseMap<int64_t>`.
942 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
943  std::map<int64_t, unsigned> numOccurences;
944  for (auto val : vals)
945  numOccurences[val]++;
946  return numOccurences;
947 }
948 
949 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
950 /// to be a subset of `originalType` with some `1` entries erased, return the
951 /// set of indices that specifies which of the entries of `originalShape` are
952 /// dropped to obtain `reducedShape`.
953 /// This accounts for cases where there are multiple unit-dims, but only a
954 /// subset of those are dropped. For MemRefTypes these can be disambiguated
955 /// using the strides. If a dimension is dropped the stride must be dropped too.
956 static std::optional<llvm::SmallBitVector>
957 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
958  ArrayRef<OpFoldResult> sizes) {
959  llvm::SmallBitVector unusedDims(originalType.getRank());
960  if (originalType.getRank() == reducedType.getRank())
961  return unusedDims;
962 
963  for (const auto &dim : llvm::enumerate(sizes))
964  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
965  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
966  unusedDims.set(dim.index());
967 
968  // Early exit for the case where the number of unused dims matches the number
969  // of ranks reduced.
970  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
971  originalType.getRank())
972  return unusedDims;
973 
974  SmallVector<int64_t> originalStrides, candidateStrides;
975  int64_t originalOffset, candidateOffset;
976  if (failed(
977  getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
978  failed(
979  getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
980  return std::nullopt;
981 
982  // For memrefs, a dimension is truly dropped if its corresponding stride is
983  // also dropped. This is particularly important when more than one of the dims
984  // is 1. Track the number of occurrences of the strides in the original type
985  // and the candidate type. For each unused dim that stride should not be
986  // present in the candidate type. Note that there could be multiple dimensions
987  // that have the same size. We don't need to exactly figure out which dim
988  // corresponds to which stride, we just need to verify that the number of
989  // repetitions of a stride in the original + number of unused dims with that
990  // stride == number of repetitions of a stride in the candidate.
991  std::map<int64_t, unsigned> currUnaccountedStrides =
992  getNumOccurences(originalStrides);
993  std::map<int64_t, unsigned> candidateStridesNumOccurences =
994  getNumOccurences(candidateStrides);
995  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
996  if (!unusedDims.test(dim))
997  continue;
998  int64_t originalStride = originalStrides[dim];
999  if (currUnaccountedStrides[originalStride] >
1000  candidateStridesNumOccurences[originalStride]) {
1001  // This dim can be treated as dropped.
1002  currUnaccountedStrides[originalStride]--;
1003  continue;
1004  }
1005  if (currUnaccountedStrides[originalStride] ==
1006  candidateStridesNumOccurences[originalStride]) {
1007  // The stride for this is not dropped. Keep as is.
1008  unusedDims.reset(dim);
1009  continue;
1010  }
1011  if (currUnaccountedStrides[originalStride] <
1012  candidateStridesNumOccurences[originalStride]) {
1013  // This should never happen. Can't have a stride in the reduced rank type
1014  // that wasn't in the original one.
1015  return std::nullopt;
1016  }
1017  }
1018 
1019  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1020  originalType.getRank())
1021  return std::nullopt;
1022  return unusedDims;
1023 }
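// For example, reducing memref<1x4x1xf32, strided<[4, 1, 1]>> to
// memref<1x4xf32, strided<[4, 1]>> has two unit-sized candidates (dims 0 and
// 2), but only one occurrence of stride 1 disappears, so exactly one of the
// two unit dims is reported as dropped and the other is kept.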
1024 
1025 llvm::SmallBitVector SubViewOp::getDroppedDims() {
1026  MemRefType sourceType = getSourceType();
1027  MemRefType resultType = getType();
1028  std::optional<llvm::SmallBitVector> unusedDims =
1029  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1030  assert(unusedDims && "unable to find unused dims of subview");
1031  return *unusedDims;
1032 }
1033 
1034 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1035  // All forms of folding require a known index.
1036  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1037  if (!index)
1038  return {};
1039 
1040  // Folding for unranked types (UnrankedMemRefType) is not supported.
1041  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1042  if (!memrefType)
1043  return {};
1044 
1045  // Out of bound indices produce undefined behavior but are still valid IR.
1046  // Don't choke on them.
1047  int64_t indexVal = index.getInt();
1048  if (indexVal < 0 || indexVal >= memrefType.getRank())
1049  return {};
1050 
1051  // Fold if the shape extent along the given index is known.
1052  if (!memrefType.isDynamicDim(index.getInt())) {
1053  Builder builder(getContext());
1054  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1055  }
1056 
1057  // The size at the given index is now known to be a dynamic size.
1058  unsigned unsignedIndex = index.getValue().getZExtValue();
1059 
1060  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1061  Operation *definingOp = getSource().getDefiningOp();
1062 
1063  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1064  return *(alloc.getDynamicSizes().begin() +
1065  memrefType.getDynamicDimIndex(unsignedIndex));
1066 
1067  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1068  return *(alloca.getDynamicSizes().begin() +
1069  memrefType.getDynamicDimIndex(unsignedIndex));
1070 
1071  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1072  return *(view.getDynamicSizes().begin() +
1073  memrefType.getDynamicDimIndex(unsignedIndex));
1074 
1075  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1076  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1077  unsigned resultIndex = 0;
1078  unsigned sourceRank = subview.getSourceType().getRank();
1079  unsigned sourceIndex = 0;
1080  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1081  if (unusedDims.test(i))
1082  continue;
1083  if (resultIndex == unsignedIndex) {
1084  sourceIndex = i;
1085  break;
1086  }
1087  resultIndex++;
1088  }
1089  assert(subview.isDynamicSize(sourceIndex) &&
1090  "expected dynamic subview size");
1091  return subview.getDynamicSize(sourceIndex);
1092  }
1093 
1094  if (auto sizeInterface =
1095  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1096  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1097  "Expected dynamic subview size");
1098  return sizeInterface.getDynamicSize(unsignedIndex);
1099  }
1100 
1101  // dim(memrefcast) -> dim
1102  if (succeeded(foldMemRefCast(*this)))
1103  return getResult();
1104 
1105  return {};
1106 }
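// For example, with
//
//   %alloc = memref.alloc(%d) : memref<4x?xf32>
//   %dim = memref.dim %alloc, %c1 : memref<4x?xf32>
//
// %dim folds to the dynamic size operand %d, while memref.dim %alloc, %c0
// folds to the constant index 4 (assuming %c0 and %c1 are constant indices).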
1107 
1108 namespace {
1109 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1110 /// operand.
1111 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1112  using OpRewritePattern<DimOp>::OpRewritePattern;
1113 
1114  LogicalResult matchAndRewrite(DimOp dim,
1115  PatternRewriter &rewriter) const override {
1116  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1117 
1118  if (!reshape)
1119  return failure();
1120 
1121  // Place the load directly after the reshape to ensure that the shape memref
1122  // was not mutated.
1123  rewriter.setInsertionPointAfter(reshape);
1124  Location loc = dim.getLoc();
1125  Value load =
1126  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1127  if (load.getType() != dim.getType())
1128  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1129  rewriter.replaceOp(dim, load);
1130  return success();
1131  }
1132 };
1133 
1134 } // namespace
1135 
1136 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1137  MLIRContext *context) {
1138  results.add<DimOfMemRefReshape>(context);
1139 }
1140 
1141 // ---------------------------------------------------------------------------
1142 // DmaStartOp
1143 // ---------------------------------------------------------------------------
1144 
1145 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1146  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1147  ValueRange destIndices, Value numElements,
1148  Value tagMemRef, ValueRange tagIndices, Value stride,
1149  Value elementsPerStride) {
1150  result.addOperands(srcMemRef);
1151  result.addOperands(srcIndices);
1152  result.addOperands(destMemRef);
1153  result.addOperands(destIndices);
1154  result.addOperands({numElements, tagMemRef});
1155  result.addOperands(tagIndices);
1156  if (stride)
1157  result.addOperands({stride, elementsPerStride});
1158 }
1159 
1160 void DmaStartOp::print(OpAsmPrinter &p) {
1161  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1162  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1163  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1164  if (isStrided())
1165  p << ", " << getStride() << ", " << getNumElementsPerStride();
1166 
1167  p.printOptionalAttrDict((*this)->getAttrs());
1168  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1169  << ", " << getTagMemRef().getType();
1170 }
1171 
1172 // Parse DmaStartOp.
1173 // Ex:
1174 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1175 // %tag[%index], %stride, %num_elt_per_stride :
1176 // : memref<3076 x f32, 0>,
1177 // memref<1024 x f32, 2>,
1178 // memref<1 x i32>
1179 //
1180 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1181  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1182  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1183  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1184  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1185  OpAsmParser::UnresolvedOperand numElementsInfo;
1186  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1187  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1188  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1189 
1190  SmallVector<Type, 3> types;
1191  auto indexType = parser.getBuilder().getIndexType();
1192 
1193  // Parse and resolve the following list of operands:
1194  // *) source memref followed by its indices (in square brackets).
1195  // *) destination memref followed by its indices (in square brackets).
1196  // *) dma size in KiB.
1197  if (parser.parseOperand(srcMemRefInfo) ||
1198  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1199  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1200  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1201  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1202  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1203  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1204  return failure();
1205 
1206  // Parse optional stride and elements per stride.
1207  if (parser.parseTrailingOperandList(strideInfo))
1208  return failure();
1209 
1210  bool isStrided = strideInfo.size() == 2;
1211  if (!strideInfo.empty() && !isStrided) {
1212  return parser.emitError(parser.getNameLoc(),
1213  "expected two stride related operands");
1214  }
1215 
1216  if (parser.parseColonTypeList(types))
1217  return failure();
1218  if (types.size() != 3)
1219  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1220 
1221  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1222  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1223  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1224  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1225  // size should be an index.
1226  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1227  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1228  // tag indices should be index.
1229  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1230  return failure();
1231 
1232  if (isStrided) {
1233  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1234  return failure();
1235  }
1236 
1237  return success();
1238 }
1239 
1240 LogicalResult DmaStartOp::verify() {
1241  unsigned numOperands = getNumOperands();
1242 
1243  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1244  // the number of elements.
1245  if (numOperands < 4)
1246  return emitOpError("expected at least 4 operands");
1247 
1248  // Check types of operands. The order of these calls is important: the later
1249  // calls rely on some type properties to compute the operand position.
1250  // 1. Source memref.
1251  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1252  return emitOpError("expected source to be of memref type");
1253  if (numOperands < getSrcMemRefRank() + 4)
1254  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1255  << " operands";
1256  if (!getSrcIndices().empty() &&
1257  !llvm::all_of(getSrcIndices().getTypes(),
1258  [](Type t) { return t.isIndex(); }))
1259  return emitOpError("expected source indices to be of index type");
1260 
1261  // 2. Destination memref.
1262  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1263  return emitOpError("expected destination to be of memref type");
1264  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1265  if (numOperands < numExpectedOperands)
1266  return emitOpError() << "expected at least " << numExpectedOperands
1267  << " operands";
1268  if (!getDstIndices().empty() &&
1269  !llvm::all_of(getDstIndices().getTypes(),
1270  [](Type t) { return t.isIndex(); }))
1271  return emitOpError("expected destination indices to be of index type");
1272 
1273  // 3. Number of elements.
1274  if (!getNumElements().getType().isIndex())
1275  return emitOpError("expected num elements to be of index type");
1276 
1277  // 4. Tag memref.
1278  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1279  return emitOpError("expected tag to be of memref type");
1280  numExpectedOperands += getTagMemRefRank();
1281  if (numOperands < numExpectedOperands)
1282  return emitOpError() << "expected at least " << numExpectedOperands
1283  << " operands";
1284  if (!getTagIndices().empty() &&
1285  !llvm::all_of(getTagIndices().getTypes(),
1286  [](Type t) { return t.isIndex(); }))
1287  return emitOpError("expected tag indices to be of index type");
1288 
1289  // Optional stride-related operands must be either both present or both
1290  // absent.
1291  if (numOperands != numExpectedOperands &&
1292  numOperands != numExpectedOperands + 2)
1293  return emitOpError("incorrect number of operands");
1294 
1295  // 5. Strides.
1296  if (isStrided()) {
1297  if (!getStride().getType().isIndex() ||
1298  !getNumElementsPerStride().getType().isIndex())
1299  return emitOpError(
1300  "expected stride and num elements per stride to be of type index");
1301  }
1302 
1303  return success();
1304 }
1305 
1306 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1307  SmallVectorImpl<OpFoldResult> &results) {
1308  /// dma_start(memrefcast) -> dma_start
1309  return foldMemRefCast(*this);
1310 }
1311 
1312 // ---------------------------------------------------------------------------
1313 // DmaWaitOp
1314 // ---------------------------------------------------------------------------
1315 
1316 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1317  SmallVectorImpl<OpFoldResult> &results) {
1318  /// dma_wait(memrefcast) -> dma_wait
1319  return foldMemRefCast(*this);
1320 }
1321 
1322 LogicalResult DmaWaitOp::verify() {
1323  // Check that the number of tag indices matches the tagMemRef rank.
1324  unsigned numTagIndices = getTagIndices().size();
1325  unsigned tagMemRefRank = getTagMemRefRank();
1326  if (numTagIndices != tagMemRefRank)
1327  return emitOpError() << "expected tagIndices to have the same number of "
1328  "elements as the tagMemRef rank, expected "
1329  << tagMemRefRank << ", but got " << numTagIndices;
1330  return success();
1331 }
1332 
1333 //===----------------------------------------------------------------------===//
1334 // ExtractAlignedPointerAsIndexOp
1335 //===----------------------------------------------------------------------===//
1336 
1337 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1338  function_ref<void(Value, StringRef)> setNameFn) {
1339  setNameFn(getResult(), "intptr");
1340 }
1341 
1342 //===----------------------------------------------------------------------===//
1343 // ExtractStridedMetadataOp
1344 //===----------------------------------------------------------------------===//
1345 
1346 /// The number and type of the results are inferred from the
1347 /// shape of the source.
1348 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1349  MLIRContext *context, std::optional<Location> location,
1350  ExtractStridedMetadataOp::Adaptor adaptor,
1351  SmallVectorImpl<Type> &inferredReturnTypes) {
1352  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1353  if (!sourceType)
1354  return failure();
1355 
1356  unsigned sourceRank = sourceType.getRank();
1357  IndexType indexType = IndexType::get(context);
1358  auto memrefType =
1359  MemRefType::get({}, sourceType.getElementType(),
1360  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1361  // Base.
1362  inferredReturnTypes.push_back(memrefType);
1363  // Offset.
1364  inferredReturnTypes.push_back(indexType);
1365  // Sizes and strides.
1366  for (unsigned i = 0; i < sourceRank * 2; ++i)
1367  inferredReturnTypes.push_back(indexType);
1368  return success();
1369 }
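// For example, for a source of type
// memref<?x4xf32, strided<[?, 1], offset: ?>> the inferred results are a
// 0-d memref<f32> base buffer (same element type and memory space), one
// index for the offset, two indices for the sizes and two for the strides.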
1370 
1371 void ExtractStridedMetadataOp::getAsmResultNames(
1372  function_ref<void(Value, StringRef)> setNameFn) {
1373  setNameFn(getBaseBuffer(), "base_buffer");
1374  setNameFn(getOffset(), "offset");
1375  // For multi-result to work properly with pretty names and packed syntax `x:3`
1376  // we can only give a pretty name to the first value in the pack.
1377  if (!getSizes().empty()) {
1378  setNameFn(getSizes().front(), "sizes");
1379  setNameFn(getStrides().front(), "strides");
1380  }
1381 }
1382 
1383 /// Helper function to perform the replacement of all constant uses of `values`
1384 /// by a materialized constant extracted from `maybeConstants`.
1385 /// `values` and `maybeConstants` are expected to have the same size.
1386 template <typename Container>
1387 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1388  Container values,
1389  ArrayRef<OpFoldResult> maybeConstants) {
1390  assert(values.size() == maybeConstants.size() &&
1391  " expected values and maybeConstants of the same size");
1392  bool atLeastOneReplacement = false;
1393  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1394  // Don't materialize a constant if there are no uses: this would induce
1395  // infinite loops in the driver.
1396  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1397  continue;
1398  assert(maybeConstant.template is<Attribute>() &&
1399  "The constified value should be either unchanged (i.e., == result) "
1400  "or a constant");
1401  Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1402  loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
1403  .getInt());
1404  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1405  // updateRootInplace: lambda cannot capture structured bindings in C++17
1406  // yet.
1407  op->replaceUsesOfWith(result, constantVal);
1408  atLeastOneReplacement = true;
1409  }
1410  }
1411  return atLeastOneReplacement;
1412 }
1413 
1414 LogicalResult
1415 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1416  SmallVectorImpl<OpFoldResult> &results) {
1417  OpBuilder builder(*this);
1418 
1419  bool atLeastOneReplacement = replaceConstantUsesOf(
1420  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1421  getConstifiedMixedOffset());
1422  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1423  getConstifiedMixedSizes());
1424  atLeastOneReplacement |= replaceConstantUsesOf(
1425  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1426 
1427  return success(atLeastOneReplacement);
1428 }
1429 
1430 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1431  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1432  constifyIndexValues(values, getSource().getType(), getContext(),
1433  getConstantSizes, ShapedType::isDynamic);
1434  return values;
1435 }
1436 
1437 SmallVector<OpFoldResult>
1438 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1439  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1440  constifyIndexValues(values, getSource().getType(), getContext(),
1441  getConstantStrides, ShapedType::isDynamic);
1442  return values;
1443 }
1444 
1445 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1446  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1447  SmallVector<OpFoldResult> values(1, offsetOfr);
1448  constifyIndexValues(values, getSource().getType(), getContext(),
1449  getConstantOffset, ShapedType::isDynamic);
1450  return values[0];
1451 }
1452 
1453 //===----------------------------------------------------------------------===//
1454 // GenericAtomicRMWOp
1455 //===----------------------------------------------------------------------===//
1456 
1457 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1458  Value memref, ValueRange ivs) {
1459  result.addOperands(memref);
1460  result.addOperands(ivs);
1461 
1462  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1463  Type elementType = memrefType.getElementType();
1464  result.addTypes(elementType);
1465 
1466  Region *bodyRegion = result.addRegion();
1467  bodyRegion->push_back(new Block());
1468  bodyRegion->addArgument(elementType, memref.getLoc());
1469  }
1470 }
1471 
1472 LogicalResult GenericAtomicRMWOp::verify() {
1473  auto &body = getRegion();
1474  if (body.getNumArguments() != 1)
1475  return emitOpError("expected single number of entry block arguments");
1476 
1477  if (getResult().getType() != body.getArgument(0).getType())
1478  return emitOpError("expected block argument of the same type result type");
1479 
1480  bool hasSideEffects =
1481  body.walk([&](Operation *nestedOp) {
1482  if (isMemoryEffectFree(nestedOp))
1483  return WalkResult::advance();
1484  nestedOp->emitError(
1485  "body of 'memref.generic_atomic_rmw' should contain "
1486  "only operations with no side effects");
1487  return WalkResult::interrupt();
1488  })
1489  .wasInterrupted();
1490  return hasSideEffects ? failure() : success();
1491 }
1492 
1493 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1494  OperationState &result) {
1495  OpAsmParser::UnresolvedOperand memref;
1496  Type memrefType;
1497  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1498 
1499  Type indexType = parser.getBuilder().getIndexType();
1500  if (parser.parseOperand(memref) ||
1501      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1502      parser.parseColonType(memrefType) ||
1503  parser.resolveOperand(memref, memrefType, result.operands) ||
1504  parser.resolveOperands(ivs, indexType, result.operands))
1505  return failure();
1506 
1507  Region *body = result.addRegion();
1508  if (parser.parseRegion(*body, {}) ||
1509  parser.parseOptionalAttrDict(result.attributes))
1510  return failure();
1511  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1512  return success();
1513 }
1514 
1515 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1516  p << ' ' << getMemref() << "[" << getIndices()
1517  << "] : " << getMemref().getType() << ' ';
1518  p.printRegion(getRegion());
1519  p.printOptionalAttrDict((*this)->getAttrs());
1520 }
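// Illustrative example (not from the upstream sources; names are made up) of
// the textual form handled by the parser/printer above:
//
//   %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
//   ^bb0(%current_value : f32):
//     %c1 = arith.constant 1.0 : f32
//     %inc = arith.addf %c1, %current_value : f32
//     memref.atomic_yield %inc : f32
//   }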
1521 
1522 //===----------------------------------------------------------------------===//
1523 // AtomicYieldOp
1524 //===----------------------------------------------------------------------===//
1525 
1526 LogicalResult AtomicYieldOp::verify() {
1527  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1528  Type resultType = getResult().getType();
1529  if (parentType != resultType)
1530  return emitOpError() << "types mismatch between yield op: " << resultType
1531  << " and its parent: " << parentType;
1532  return success();
1533 }
1534 
1535 //===----------------------------------------------------------------------===//
1536 // GlobalOp
1537 //===----------------------------------------------------------------------===//
1538 
1539 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1540  TypeAttr type,
1541  Attribute initialValue) {
1542  p << type;
1543  if (!op.isExternal()) {
1544  p << " = ";
1545  if (op.isUninitialized())
1546  p << "uninitialized";
1547  else
1548  p.printAttributeWithoutType(initialValue);
1549  }
1550 }
1551 
1552 static ParseResult
1553 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1554  Attribute &initialValue) {
1555  Type type;
1556  if (parser.parseType(type))
1557  return failure();
1558 
1559  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1560  if (!memrefType || !memrefType.hasStaticShape())
1561  return parser.emitError(parser.getNameLoc())
1562  << "type should be static shaped memref, but got " << type;
1563  typeAttr = TypeAttr::get(type);
1564 
1565  if (parser.parseOptionalEqual())
1566  return success();
1567 
1568  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1569  initialValue = UnitAttr::get(parser.getContext());
1570  return success();
1571  }
1572 
1573  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1574  if (parser.parseAttribute(initialValue, tensorType))
1575  return failure();
1576  if (!llvm::isa<ElementsAttr>(initialValue))
1577  return parser.emitError(parser.getNameLoc())
1578  << "initial value should be a unit or elements attribute";
1579  return success();
1580 }
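// Illustrative examples (not from the upstream sources; names are made up) of
// the type-and-initializer syntax handled above:
//
//   memref.global "private" @x : memref<2xf32> = dense<0.0>
//   memref.global @y : memref<4xi32> = uninitialized
//   memref.global constant @c : memref<2xi32> = dense<[1, 4]>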
1581 
1582 LogicalResult GlobalOp::verify() {
1583  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1584  if (!memrefType || !memrefType.hasStaticShape())
1585  return emitOpError("type should be static shaped memref, but got ")
1586  << getType();
1587 
1588  // Verify that the initial value, if present, is either a unit attribute or
1589  // an elements attribute.
1590  if (getInitialValue().has_value()) {
1591  Attribute initValue = getInitialValue().value();
1592  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1593  return emitOpError("initial value should be a unit or elements "
1594  "attribute, but got ")
1595  << initValue;
1596 
1597  // Check that the type of the initial value is compatible with the type of
1598  // the global variable.
1599  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1600  Type initType = elementsAttr.getType();
1601  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1602  if (initType != tensorType)
1603  return emitOpError("initial value expected to be of type ")
1604  << tensorType << ", but was of type " << initType;
1605  }
1606  }
1607 
1608  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1609  uint64_t alignment = *alignAttr;
1610 
1611  if (!llvm::isPowerOf2_64(alignment))
1612  return emitError() << "alignment attribute value " << alignment
1613  << " is not a power of 2";
1614  }
1615 
1616  // TODO: verify visibility for declarations.
1617  return success();
1618 }
1619 
1620 ElementsAttr GlobalOp::getConstantInitValue() {
1621  auto initVal = getInitialValue();
1622  if (getConstant() && initVal.has_value())
1623  return llvm::cast<ElementsAttr>(initVal.value());
1624  return {};
1625 }
1626 
1627 //===----------------------------------------------------------------------===//
1628 // GetGlobalOp
1629 //===----------------------------------------------------------------------===//
1630 
1631 LogicalResult
1632 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1633  // Verify that the result type is same as the type of the referenced
1634  // memref.global op.
1635  auto global =
1636  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1637  if (!global)
1638  return emitOpError("'")
1639  << getName() << "' does not reference a valid global memref";
1640 
1641  Type resultType = getResult().getType();
1642  if (global.getType() != resultType)
1643  return emitOpError("result type ")
1644  << resultType << " does not match type " << global.getType()
1645  << " of the global memref @" << getName();
1646  return success();
1647 }
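// Illustrative example (not from the upstream sources; names are made up): the
// symbol check above accepts
//
//   memref.global constant @c : memref<2xi32> = dense<[1, 4]>
//   ...
//   %0 = memref.get_global @c : memref<2xi32>
//
// and rejects the get_global if @c is missing or its type differs from the
// result type.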
1648 
1649 //===----------------------------------------------------------------------===//
1650 // LoadOp
1651 //===----------------------------------------------------------------------===//
1652 
1653 LogicalResult LoadOp::verify() {
1654  if (getNumOperands() != 1 + getMemRefType().getRank())
1655  return emitOpError("incorrect number of indices for load");
1656  return success();
1657 }
1658 
1659 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1660  /// load(memrefcast) -> load
1661  if (succeeded(foldMemRefCast(*this)))
1662  return getResult();
1663  return OpFoldResult();
1664 }
1665 
1666 //===----------------------------------------------------------------------===//
1667 // MemorySpaceCastOp
1668 //===----------------------------------------------------------------------===//
1669 
1670 void MemorySpaceCastOp::getAsmResultNames(
1671  function_ref<void(Value, StringRef)> setNameFn) {
1672  setNameFn(getResult(), "memspacecast");
1673 }
1674 
1675 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1676  if (inputs.size() != 1 || outputs.size() != 1)
1677  return false;
1678  Type a = inputs.front(), b = outputs.front();
1679  auto aT = llvm::dyn_cast<MemRefType>(a);
1680  auto bT = llvm::dyn_cast<MemRefType>(b);
1681 
1682  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1683  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1684 
1685  if (aT && bT) {
1686  if (aT.getElementType() != bT.getElementType())
1687  return false;
1688  if (aT.getLayout() != bT.getLayout())
1689  return false;
1690  if (aT.getShape() != bT.getShape())
1691  return false;
1692  return true;
1693  }
1694  if (uaT && ubT) {
1695  return uaT.getElementType() == ubT.getElementType();
1696  }
1697  return false;
1698 }
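// Illustrative example (not from the upstream sources; the memory space number
// is made up): a compatible cast keeps element type, shape and layout and only
// changes the memory space, e.g.
//
//   %1 = memref.memory_space_cast %0 : memref<4x?xf32, 1> to memref<4x?xf32>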
1699 
1700 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1701  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1702  // t2)
1703  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1704  getSourceMutable().assign(parentCast.getSource());
1705  return getResult();
1706  }
1707  return Value{};
1708 }
1709 
1710 //===----------------------------------------------------------------------===//
1711 // PrefetchOp
1712 //===----------------------------------------------------------------------===//
1713 
1714 void PrefetchOp::print(OpAsmPrinter &p) {
1715  p << " " << getMemref() << '[';
1716  p.printOperands(getIndices());
1717  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1718  p << ", locality<" << getLocalityHint();
1719  p << ">, " << (getIsDataCache() ? "data" : "instr");
1720  p.printOptionalAttrDict(
1721  (*this)->getAttrs(),
1722  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1723  p << " : " << getMemRefType();
1724 }
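// Illustrative example (not from the upstream sources; names are made up) of
// the syntax produced and consumed by the printer/parser above:
//
//   memref.prefetch %A[%i, %j], read, locality<3>, data : memref<400x400xf32>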
1725 
1726 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1727  OpAsmParser::UnresolvedOperand memrefInfo;
1728  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1729  IntegerAttr localityHint;
1730  MemRefType type;
1731  StringRef readOrWrite, cacheType;
1732 
1733  auto indexTy = parser.getBuilder().getIndexType();
1734  auto i32Type = parser.getBuilder().getIntegerType(32);
1735  if (parser.parseOperand(memrefInfo) ||
1736  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1737  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1738  parser.parseComma() || parser.parseKeyword("locality") ||
1739  parser.parseLess() ||
1740  parser.parseAttribute(localityHint, i32Type, "localityHint",
1741  result.attributes) ||
1742  parser.parseGreater() || parser.parseComma() ||
1743  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1744  parser.resolveOperand(memrefInfo, type, result.operands) ||
1745  parser.resolveOperands(indexInfo, indexTy, result.operands))
1746  return failure();
1747 
1748  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1749  return parser.emitError(parser.getNameLoc(),
1750  "rw specifier has to be 'read' or 'write'");
1751  result.addAttribute(
1752  PrefetchOp::getIsWriteAttrStrName(),
1753  parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1754 
1755  if (!cacheType.equals("data") && !cacheType.equals("instr"))
1756  return parser.emitError(parser.getNameLoc(),
1757  "cache type has to be 'data' or 'instr'");
1758 
1759  result.addAttribute(
1760  PrefetchOp::getIsDataCacheAttrStrName(),
1761  parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1762 
1763  return success();
1764 }
1765 
1766 LogicalResult PrefetchOp::verify() {
1767  if (getNumOperands() != 1 + getMemRefType().getRank())
1768  return emitOpError("too few indices");
1769 
1770  return success();
1771 }
1772 
1773 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1774  SmallVectorImpl<OpFoldResult> &results) {
1775  // prefetch(memrefcast) -> prefetch
1776  return foldMemRefCast(*this);
1777 }
1778 
1779 //===----------------------------------------------------------------------===//
1780 // RankOp
1781 //===----------------------------------------------------------------------===//
1782 
1783 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1784  // Constant fold rank when the rank of the operand is known.
1785  auto type = getOperand().getType();
1786  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1787  if (shapedType && shapedType.hasRank())
1788  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1789  return IntegerAttr();
1790 }
1791 
1792 //===----------------------------------------------------------------------===//
1793 // ReinterpretCastOp
1794 //===----------------------------------------------------------------------===//
1795 
1796 void ReinterpretCastOp::getAsmResultNames(
1797  function_ref<void(Value, StringRef)> setNameFn) {
1798  setNameFn(getResult(), "reinterpret_cast");
1799 }
1800 
1801 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1802 /// `staticSizes` and `staticStrides` are automatically filled with
1803 /// source-memref-rank sentinel values that encode dynamic entries.
1804 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1805  MemRefType resultType, Value source,
1806  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1807  ArrayRef<OpFoldResult> strides,
1808  ArrayRef<NamedAttribute> attrs) {
1809  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1810  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1811  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1812  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1813  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1814  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1815  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1816  b.getDenseI64ArrayAttr(staticSizes),
1817  b.getDenseI64ArrayAttr(staticStrides));
1818  result.addAttributes(attrs);
1819 }
1820 
1821 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1822  MemRefType resultType, Value source,
1823  int64_t offset, ArrayRef<int64_t> sizes,
1824  ArrayRef<int64_t> strides,
1825  ArrayRef<NamedAttribute> attrs) {
1826  SmallVector<OpFoldResult> sizeValues =
1827  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1828  return b.getI64IntegerAttr(v);
1829  }));
1830  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1831  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1832  return b.getI64IntegerAttr(v);
1833  }));
1834  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1835  strideValues, attrs);
1836 }
1837 
1838 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1839  MemRefType resultType, Value source, Value offset,
1840  ValueRange sizes, ValueRange strides,
1841  ArrayRef<NamedAttribute> attrs) {
1842  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1843  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1844  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1845  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1846  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1847 }
1848 
1849 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1850 // completed automatically, like we have for subview and extract_slice.
1851 LogicalResult ReinterpretCastOp::verify() {
1852  // The source and result memrefs should be in the same memory space.
1853  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1854  auto resultType = llvm::cast<MemRefType>(getType());
1855  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1856  return emitError("different memory spaces specified for source type ")
1857  << srcType << " and result memref type " << resultType;
1858  if (srcType.getElementType() != resultType.getElementType())
1859  return emitError("different element types specified for source type ")
1860  << srcType << " and result memref type " << resultType;
1861 
1862  // Match sizes in result memref type and in static_sizes attribute.
1863  for (auto [idx, resultSize, expectedSize] :
1864  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1865  if (!ShapedType::isDynamic(resultSize) &&
1866  !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1867  return emitError("expected result type with size = ")
1868  << expectedSize << " instead of " << resultSize
1869  << " in dim = " << idx;
1870  }
1871 
1872  // Match offset and strides in static_offset and static_strides attributes. If
1873  // result memref type has no affine map specified, this will assume an
1874  // identity layout.
1875  int64_t resultOffset;
1876  SmallVector<int64_t, 4> resultStrides;
1877  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1878  return emitError("expected result type to have strided layout but found ")
1879  << resultType;
1880 
1881  // Match offset in result memref type and in static_offsets attribute.
1882  int64_t expectedOffset = getStaticOffsets().front();
1883  if (!ShapedType::isDynamic(resultOffset) &&
1884  !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
1885  return emitError("expected result type with offset = ")
1886  << expectedOffset << " instead of " << resultOffset;
1887 
1888  // Match strides in result memref type and in static_strides attribute.
1889  for (auto [idx, resultStride, expectedStride] :
1890  llvm::enumerate(resultStrides, getStaticStrides())) {
1891  if (!ShapedType::isDynamic(resultStride) &&
1892  !ShapedType::isDynamic(expectedStride) &&
1893  resultStride != expectedStride)
1894  return emitError("expected result type with stride = ")
1895  << expectedStride << " instead of " << resultStride
1896  << " in dim = " << idx;
1897  }
1898 
1899  return success();
1900 }
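// Illustrative example (not from the upstream sources; names are made up) of a
// cast the verifier above accepts; static entries in the result type must
// agree with the static_offsets/static_sizes/static_strides attributes:
//
//   %dst = memref.reinterpret_cast %src to
//       offset: [0], sizes: [%size0, 10], strides: [1, %stride1]
//       : memref<?x?xf32> to memref<?x10xf32, strided<[1, ?], offset: 0>>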
1901 
1902 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1903  Value src = getSource();
1904  auto getPrevSrc = [&]() -> Value {
1905  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1906  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1907  return prev.getSource();
1908 
1909  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1910  if (auto prev = src.getDefiningOp<CastOp>())
1911  return prev.getSource();
1912 
1913  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1914  // are 0.
1915  if (auto prev = src.getDefiningOp<SubViewOp>())
1916  if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1917  return isConstantIntValue(val, 0);
1918  }))
1919  return prev.getSource();
1920 
1921  return nullptr;
1922  };
1923 
1924  if (auto prevSrc = getPrevSrc()) {
1925  getSourceMutable().assign(prevSrc);
1926  return getResult();
1927  }
1928 
1929  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1930  if (!ShapedType::isDynamicShape(getType().getShape()) &&
1931  src.getType() == getType() && getStaticOffsets().front() == 0) {
1932  return src;
1933  }
1934 
1935  return nullptr;
1936 }
1937 
1938 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1939  SmallVector<OpFoldResult> values = getMixedSizes();
1940  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
1941  ShapedType::isDynamic);
1942  return values;
1943 }
1944 
1945 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1946  SmallVector<OpFoldResult> values = getMixedStrides();
1947  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
1948  ShapedType::isDynamic);
1949  return values;
1950 }
1951 
1952 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1953  SmallVector<OpFoldResult> values = getMixedOffsets();
1954  assert(values.size() == 1 &&
1955  "reinterpret_cast must have one and only one offset");
1956  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
1957  ShapedType::isDynamic);
1958  return values[0];
1959 }
1960 
1961 namespace {
1962 /// Replace the sequence:
1963 /// ```
1964 /// base, offset, sizes, strides = extract_strided_metadata src
1965 /// dst = reinterpret_cast base to offset, sizes, strides
1966 /// ```
1967 /// With
1968 ///
1969 /// ```
1970 /// dst = memref.cast src
1971 /// ```
1972 ///
1973 /// Note: The cast operation is only inserted when the type of dst and src
1974 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
1975 ///
1976 /// This pattern also matches when the offset, sizes, and strides don't come
1977 /// directly from the `extract_strided_metadata`'s results but it can be
1978 /// statically proven that they would hold the same values.
1979 ///
1980 /// For instance, the following sequence would be replaced:
1981 /// ```
1982 /// base, offset, sizes, strides =
1983 /// extract_strided_metadata memref : memref<3x4xty>
1984 /// dst = reinterpret_cast base to 0, [3, 4], strides
1985 /// ```
1986 /// Because we know (thanks to the type of the input memref) that variable
1987 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
1988 ///
1989 /// Similarly, the following sequence would be replaced:
1990 /// ```
1991 /// c0 = arith.constant 0
1992 /// c4 = arith.constant 4
1993 /// base, offset, sizes, strides =
1994 /// extract_strided_metadata memref : memref<3x4xty>
1995 /// dst = reinterpret_cast base to c0, [3, c4], strides
1996 /// ```
1997 /// Because we know that `offset` and `c0` will hold 0
1998 /// and `c4` will hold 4.
1999 struct ReinterpretCastOpExtractStridedMetadataFolder
2000  : public OpRewritePattern<ReinterpretCastOp> {
2001 public:
2002  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2003 
2004  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2005  PatternRewriter &rewriter) const override {
2006  auto extractStridedMetadata =
2007  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2008  if (!extractStridedMetadata)
2009  return failure();
2010  // Check if the reinterpret cast reconstructs a memref with the exact same
2011  // properties as the extract strided metadata.
2012 
2013  // First, check that the strides are the same.
2014  SmallVector<OpFoldResult> extractStridesOfr =
2015  extractStridedMetadata.getConstifiedMixedStrides();
2016  SmallVector<OpFoldResult> reinterpretStridesOfr =
2017  op.getConstifiedMixedStrides();
2018  if (extractStridesOfr.size() != reinterpretStridesOfr.size())
2019  return failure();
2020 
2021  unsigned rank = op.getType().getRank();
2022  for (unsigned i = 0; i < rank; ++i) {
2023  if (extractStridesOfr[i] != reinterpretStridesOfr[i])
2024  return failure();
2025  }
2026 
2027  // Second, check the sizes.
2028  assert(extractStridedMetadata.getSizes().size() ==
2029  op.getMixedSizes().size() &&
2030  "Strides and sizes rank must match");
2031  SmallVector<OpFoldResult> extractSizesOfr =
2032  extractStridedMetadata.getConstifiedMixedSizes();
2033  SmallVector<OpFoldResult> reinterpretSizesOfr =
2034  op.getConstifiedMixedSizes();
2035  for (unsigned i = 0; i < rank; ++i) {
2036  if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2037  return failure();
2038  }
2039  // Finally, check the offset.
2040  assert(op.getMixedOffsets().size() == 1 &&
2041  "reinterpret_cast with more than one offset should have been "
2042  "rejected by the verifier");
2043  OpFoldResult extractOffsetOfr =
2044  extractStridedMetadata.getConstifiedMixedOffset();
2045  OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2046  if (extractOffsetOfr != reinterpretOffsetOfr)
2047  return failure();
2048 
2049  // At this point, we know that the back and forth between extract strided
2050  // metadata and reinterpret cast is a noop. However, the final type of the
2051  // reinterpret cast may not be exactly the same as the original memref.
2052  // E.g., it could be changing a dimension from static to dynamic. Check that
2053  // here and add a cast if necessary.
2054  Type srcTy = extractStridedMetadata.getSource().getType();
2055  if (srcTy == op.getResult().getType())
2056  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2057  else
2058  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2059  extractStridedMetadata.getSource());
2060 
2061  return success();
2062  }
2063 };
2064 } // namespace
2065 
2066 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2067  MLIRContext *context) {
2068  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2069 }
2070 
2071 //===----------------------------------------------------------------------===//
2072 // Reassociative reshape ops
2073 //===----------------------------------------------------------------------===//
2074 
2075 void CollapseShapeOp::getAsmResultNames(
2076  function_ref<void(Value, StringRef)> setNameFn) {
2077  setNameFn(getResult(), "collapse_shape");
2078 }
2079 
2080 void ExpandShapeOp::getAsmResultNames(
2081  function_ref<void(Value, StringRef)> setNameFn) {
2082  setNameFn(getResult(), "expand_shape");
2083 }
2084 
2085 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2086 /// result and operand. Layout maps are verified separately.
2087 ///
2088 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2089 /// allowed in a reassociation group.
2090 static LogicalResult
2091 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2092  ArrayRef<int64_t> expandedShape,
2093  ArrayRef<ReassociationIndices> reassociation,
2094  bool allowMultipleDynamicDimsPerGroup) {
2095  // There must be one reassociation group per collapsed dimension.
2096  if (collapsedShape.size() != reassociation.size())
2097  return op->emitOpError("invalid number of reassociation groups: found ")
2098  << reassociation.size() << ", expected " << collapsedShape.size();
2099 
2100  // The next expected expanded dimension index (while iterating over
2101  // reassociation indices).
2102  int64_t nextDim = 0;
2103  for (const auto &it : llvm::enumerate(reassociation)) {
2104  ReassociationIndices group = it.value();
2105  int64_t collapsedDim = it.index();
2106 
2107  bool foundDynamic = false;
2108  for (int64_t expandedDim : group) {
2109  if (expandedDim != nextDim++)
2110  return op->emitOpError("reassociation indices must be contiguous");
2111 
2112  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2113  return op->emitOpError("reassociation index ")
2114  << expandedDim << " is out of bounds";
2115 
2116  // Check if there are multiple dynamic dims in a reassociation group.
2117  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2118  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2119  return op->emitOpError(
2120  "at most one dimension in a reassociation group may be dynamic");
2121  foundDynamic = true;
2122  }
2123  }
2124 
2125  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2126  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2127  return op->emitOpError("collapsed dim (")
2128  << collapsedDim
2129  << ") must be dynamic if and only if reassociation group is "
2130  "dynamic";
2131 
2132  // If all dims in the reassociation group are static, the size of the
2133  // collapsed dim can be verified.
2134  if (!foundDynamic) {
2135  int64_t groupSize = 1;
2136  for (int64_t expandedDim : group)
2137  groupSize *= expandedShape[expandedDim];
2138  if (groupSize != collapsedShape[collapsedDim])
2139  return op->emitOpError("collapsed dim size (")
2140  << collapsedShape[collapsedDim]
2141  << ") must equal reassociation group size (" << groupSize << ")";
2142  }
2143  }
2144 
2145  if (collapsedShape.empty()) {
2146  // Rank 0: All expanded dimensions must be 1.
2147  for (int64_t d : expandedShape)
2148  if (d != 1)
2149  return op->emitOpError(
2150  "rank 0 memrefs can only be extended/collapsed with/from ones");
2151  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2152  // Rank >= 1: Number of dimensions among all reassociation groups must match
2153  // the result memref rank.
2154  return op->emitOpError("expanded rank (")
2155  << expandedShape.size()
2156  << ") inconsistent with number of reassociation indices (" << nextDim
2157  << ")";
2158  }
2159 
2160  return success();
2161 }
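// Illustrative example (not from the upstream sources; names are made up): with
// reassociation [[0, 1], [2]] the helper above accepts
//
//   memref.collapse_shape %a [[0, 1], [2]]
//       : memref<3x4x5xf32> into memref<12x5xf32>
//
// because 3 * 4 == 12, and would reject a collapsed shape of memref<13x5xf32>
// or a group that mixes a dynamic expanded dim with a static collapsed dim.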
2162 
2163 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2164  return getSymbolLessAffineMaps(getReassociationExprs());
2165 }
2166 
2167 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2168  return convertReassociationIndicesToExprs(getContext(),
2169  getReassociationIndices());
2170 }
2171 
2172 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2173  return getSymbolLessAffineMaps(getReassociationExprs());
2174 }
2175 
2176 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2177  return convertReassociationIndicesToExprs(getContext(),
2178  getReassociationIndices());
2179 }
2180 
2181 /// Compute the layout map after expanding a given source MemRef type with the
2182 /// specified reassociation indices.
2183 static FailureOr<StridedLayoutAttr>
2184 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2185  ArrayRef<ReassociationIndices> reassociation) {
2186  int64_t srcOffset;
2187  SmallVector<int64_t> srcStrides;
2188  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2189  return failure();
2190  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2191 
2192  // 1-1 mapping between srcStrides and reassociation packs.
2193  // Each srcStride starts with the given value and gets expanded according to
2194  // the proper entries in resultShape.
2195  // Example:
2196  // srcStrides = [10000, 1 , 100 ],
2197  // reassociations = [ [0], [1], [2, 3, 4]],
2198  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2199  // -> For the purpose of stride calculation, the useful sizes are:
2200  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2201  // resultStrides = [10000, 1, 600, 200, 100]
2202  // Note that a stride does not get expanded along the first entry of each
2203  // shape pack.
2204  SmallVector<int64_t> reverseResultStrides;
2205  reverseResultStrides.reserve(resultShape.size());
2206  unsigned shapeIndex = resultShape.size() - 1;
2207  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2208  ReassociationIndices reassoc = std::get<0>(it);
2209  int64_t currentStrideToExpand = std::get<1>(it);
2210  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2211  using saturated_arith::Wrapper;
2212  reverseResultStrides.push_back(currentStrideToExpand);
2213  currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
2214  Wrapper::size(resultShape[shapeIndex--]))
2215  .asStride();
2216  }
2217  }
2218  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2219  resultStrides.resize(resultShape.size(), 1);
2220  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2221 }
2222 
2223 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2224  MemRefType srcType, ArrayRef<int64_t> resultShape,
2225  ArrayRef<ReassociationIndices> reassociation) {
2226  if (srcType.getLayout().isIdentity()) {
2227  // If the source is contiguous (i.e., no layout map specified), so is the
2228  // result.
2229  MemRefLayoutAttrInterface layout;
2230  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2231  srcType.getMemorySpace());
2232  }
2233 
2234  // Source may not be contiguous. Compute the layout map.
2235  FailureOr<StridedLayoutAttr> computedLayout =
2236  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2237  if (failed(computedLayout))
2238  return failure();
2239  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2240  srcType.getMemorySpace());
2241 }
2242 
2243 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2244  ArrayRef<int64_t> resultShape, Value src,
2245  ArrayRef<ReassociationIndices> reassociation) {
2246  // Only ranked memref source values are supported.
2247  auto srcType = llvm::cast<MemRefType>(src.getType());
2248  FailureOr<MemRefType> resultType =
2249  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2250  // Failure of this assertion usually indicates a problem with the source
2251  // type, e.g., could not get strides/offset.
2252  assert(succeeded(resultType) && "could not compute layout");
2253  build(builder, result, *resultType, src, reassociation);
2254 }
2255 
2256 LogicalResult ExpandShapeOp::verify() {
2257  MemRefType srcType = getSrcType();
2258  MemRefType resultType = getResultType();
2259 
2260  if (srcType.getRank() >= resultType.getRank())
2261  return emitOpError("expected rank expansion, but found source rank ")
2262  << srcType.getRank() << " >= result rank " << resultType.getRank();
2263 
2264  // Verify result shape.
2265  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2266  resultType.getShape(),
2267  getReassociationIndices(),
2268  /*allowMultipleDynamicDimsPerGroup=*/false)))
2269  return failure();
2270 
2271  // Compute expected result type (including layout map).
2272  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2273  srcType, resultType.getShape(), getReassociationIndices());
2274  if (failed(expectedResultType))
2275  return emitOpError("invalid source layout map");
2276 
2277  // Check actual result type.
2278  if (*expectedResultType != resultType)
2279  return emitOpError("expected expanded type to be ")
2280  << *expectedResultType << " but found " << resultType;
2281 
2282  return success();
2283 }
2284 
2285 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2286  MLIRContext *context) {
2289  context);
2290 }
2291 
2292 /// Compute the layout map after collapsing a given source MemRef type with the
2293 /// specified reassociation indices.
2294 ///
2295 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2296 /// not possible to check this by inspecting a MemRefType in the general case.
2297 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2298 /// be valid (and thus accepted by this function) unless `strict = true`.
2299 static FailureOr<StridedLayoutAttr>
2300 computeCollapsedLayoutMap(MemRefType srcType,
2301  ArrayRef<ReassociationIndices> reassociation,
2302  bool strict = false) {
2303  int64_t srcOffset;
2304  SmallVector<int64_t> srcStrides;
2305  auto srcShape = srcType.getShape();
2306  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2307  return failure();
2308 
2309  // The result stride of a reassociation group is the stride of the last entry
2310  // of the reassociation. (TODO: Should be the minimum stride in the
2311  // reassociation because strides are not necessarily sorted. E.g., when using
2312  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2313  // strides are meaningless and could have any arbitrary value.
2314  SmallVector<int64_t> resultStrides;
2315  resultStrides.reserve(reassociation.size());
2316  for (const ReassociationIndices &reassoc : reassociation) {
2317  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2318  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2319  ref = ref.drop_back();
2320  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2321  resultStrides.push_back(srcStrides[ref.back()]);
2322  } else {
2323  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2324  // the corresponding stride may have to be skipped. (See above comment.)
2325  // Therefore, the result stride cannot be statically determined and must
2326  // be dynamic.
2327  resultStrides.push_back(ShapedType::kDynamic);
2328  }
2329  }
2330 
2331  // Validate that each reassociation group is contiguous.
2332  unsigned resultStrideIndex = resultStrides.size() - 1;
2333  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2334  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2335  using saturated_arith::Wrapper;
2336  auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
2337  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2338  stride = stride * Wrapper::size(srcShape[idx]);
2339 
2340  // Both source and result stride must have the same static value. In that
2341  // case, we can be sure that the dimensions are collapsible (because they
2342  // are contiguous).
2343  // If `strict = false` (default during op verification), we accept cases
2344  // where one or both strides are dynamic. This is best effort: We reject
2345  // ops where obviously non-contiguous dims are collapsed, but accept ops
2346  // where we cannot be sure statically. Such ops may fail at runtime. See
2347  // the op documentation for details.
2348  auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
2349  if (strict && (stride.saturated || srcStride.saturated))
2350  return failure();
2351 
2352  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2353  return failure();
2354  }
2355  }
2356  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2357 }
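// Illustrative worked example (not from the upstream sources) of the
// contiguity check above:
//   srcShape = [3, 4, 5], srcStrides = [20, 5, 1], reassociation = [[0, 1], [2]]
//   -> resultStrides = [5, 1]; group [0, 1] is accepted because
//      srcStrides[0] == srcStrides[1] * srcShape[1] (20 == 5 * 4).
// With srcStrides = [40, 5, 1] the same grouping is rejected when the strides
// are static, since 40 != 5 * 4.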
2358 
2359 bool CollapseShapeOp::isGuaranteedCollapsible(
2360  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2361  // MemRefs with identity layout are always collapsible.
2362  if (srcType.getLayout().isIdentity())
2363  return true;
2364 
2365  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2366  /*strict=*/true));
2367 }
2368 
2369 MemRefType CollapseShapeOp::computeCollapsedType(
2370  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2371  SmallVector<int64_t> resultShape;
2372  resultShape.reserve(reassociation.size());
2373  for (const ReassociationIndices &group : reassociation) {
2374  using saturated_arith::Wrapper;
2375  auto groupSize = Wrapper::size(1);
2376  for (int64_t srcDim : group)
2377  groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
2378  resultShape.push_back(groupSize.asSize());
2379  }
2380 
2381  if (srcType.getLayout().isIdentity()) {
2382  // If the source is contiguous (i.e., no layout map specified), so is the
2383  // result.
2384  MemRefLayoutAttrInterface layout;
2385  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2386  srcType.getMemorySpace());
2387  }
2388 
2389  // Source may not be fully contiguous. Compute the layout map.
2390  // Note: Dimensions that are collapsed into a single dim are assumed to be
2391  // contiguous.
2392  FailureOr<StridedLayoutAttr> computedLayout =
2393  computeCollapsedLayoutMap(srcType, reassociation);
2394  assert(succeeded(computedLayout) &&
2395  "invalid source layout map or collapsing non-contiguous dims");
2396  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2397  srcType.getMemorySpace());
2398 }
2399 
2400 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2401  ArrayRef<ReassociationIndices> reassociation,
2402  ArrayRef<NamedAttribute> attrs) {
2403  auto srcType = llvm::cast<MemRefType>(src.getType());
2404  MemRefType resultType =
2405  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2406  build(b, result, resultType, src, attrs);
2407  result.addAttribute(getReassociationAttrStrName(),
2408  getReassociationIndicesAttribute(b, reassociation));
2409 }
2410 
2411 LogicalResult CollapseShapeOp::verify() {
2412  MemRefType srcType = getSrcType();
2413  MemRefType resultType = getResultType();
2414 
2415  if (srcType.getRank() <= resultType.getRank())
2416  return emitOpError("expected rank reduction, but found source rank ")
2417  << srcType.getRank() << " <= result rank " << resultType.getRank();
2418 
2419  // Verify result shape.
2420  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2421  srcType.getShape(), getReassociationIndices(),
2422  /*allowMultipleDynamicDimsPerGroup=*/true)))
2423  return failure();
2424 
2425  // Compute expected result type (including layout map).
2426  MemRefType expectedResultType;
2427  if (srcType.getLayout().isIdentity()) {
2428  // If the source is contiguous (i.e., no layout map specified), so is the
2429  // result.
2430  MemRefLayoutAttrInterface layout;
2431  expectedResultType =
2432  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2433  srcType.getMemorySpace());
2434  } else {
2435  // Source may not be fully contiguous. Compute the layout map.
2436  // Note: Dimensions that are collapsed into a single dim are assumed to be
2437  // contiguous.
2438  FailureOr<StridedLayoutAttr> computedLayout =
2439  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2440  if (failed(computedLayout))
2441  return emitOpError(
2442  "invalid source layout map or collapsing non-contiguous dims");
2443  expectedResultType =
2444  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2445  *computedLayout, srcType.getMemorySpace());
2446  }
2447 
2448  if (expectedResultType != resultType)
2449  return emitOpError("expected collapsed type to be ")
2450  << expectedResultType << " but found " << resultType;
2451 
2452  return success();
2453 }
2454 
2455 struct CollapseShapeOpMemRefCastFolder
2456  : public OpRewritePattern<CollapseShapeOp> {
2457 public:
2458  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2459 
2460  LogicalResult matchAndRewrite(CollapseShapeOp op,
2461  PatternRewriter &rewriter) const override {
2462  auto cast = op.getOperand().getDefiningOp<CastOp>();
2463  if (!cast)
2464  return failure();
2465 
2466  if (!CastOp::canFoldIntoConsumerOp(cast))
2467  return failure();
2468 
2469  Type newResultType = CollapseShapeOp::computeCollapsedType(
2470  llvm::cast<MemRefType>(cast.getOperand().getType()),
2471  op.getReassociationIndices());
2472 
2473  if (newResultType == op.getResultType()) {
2474  rewriter.updateRootInPlace(
2475  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2476  } else {
2477  Value newOp = rewriter.create<CollapseShapeOp>(
2478  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2479  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2480  }
2481  return success();
2482  }
2483 };
2484 
2485 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2486  MLIRContext *context) {
2490 }
2491 
2492 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2493  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2494  adaptor.getOperands());
2495 }
2496 
2497 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2498  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2499  adaptor.getOperands());
2500 }
2501 
2502 //===----------------------------------------------------------------------===//
2503 // ReshapeOp
2504 //===----------------------------------------------------------------------===//
2505 
2506 void ReshapeOp::getAsmResultNames(
2507  function_ref<void(Value, StringRef)> setNameFn) {
2508  setNameFn(getResult(), "reshape");
2509 }
2510 
2511 LogicalResult ReshapeOp::verify() {
2512  Type operandType = getSource().getType();
2513  Type resultType = getResult().getType();
2514 
2515  Type operandElementType =
2516  llvm::cast<ShapedType>(operandType).getElementType();
2517  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2518  if (operandElementType != resultElementType)
2519  return emitOpError("element types of source and destination memref "
2520  "types should be the same");
2521 
2522  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2523  if (!operandMemRefType.getLayout().isIdentity())
2524  return emitOpError("source memref type should have identity affine map");
2525 
2526  int64_t shapeSize =
2527  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2528  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2529  if (resultMemRefType) {
2530  if (!resultMemRefType.getLayout().isIdentity())
2531  return emitOpError("result memref type should have identity affine map");
2532  if (shapeSize == ShapedType::kDynamic)
2533  return emitOpError("cannot use shape operand with dynamic length to "
2534  "reshape to statically-ranked memref type");
2535  if (shapeSize != resultMemRefType.getRank())
2536  return emitOpError(
2537  "length of shape operand differs from the result's memref rank");
2538  }
2539  return success();
2540 }
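// Illustrative example (not from the upstream sources; names are made up) of a
// reshape accepted by the verifier above; the shape operand has static length
// 1, matching the result rank:
//
//   %dst = memref.reshape %src(%shape)
//       : (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>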
2541 
2542 //===----------------------------------------------------------------------===//
2543 // StoreOp
2544 //===----------------------------------------------------------------------===//
2545 
2546 LogicalResult StoreOp::verify() {
2547  if (getNumOperands() != 2 + getMemRefType().getRank())
2548  return emitOpError("store index operand count not equal to memref rank");
2549 
2550  return success();
2551 }
2552 
2553 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2554  SmallVectorImpl<OpFoldResult> &results) {
2555  /// store(memrefcast) -> store
2556  return foldMemRefCast(*this, getValueToStore());
2557 }
2558 
2559 //===----------------------------------------------------------------------===//
2560 // SubViewOp
2561 //===----------------------------------------------------------------------===//
2562 
2563 void SubViewOp::getAsmResultNames(
2564  function_ref<void(Value, StringRef)> setNameFn) {
2565  setNameFn(getResult(), "subview");
2566 }
2567 
2568 /// A subview result type can be fully inferred from the source type and the
2569 /// static representation of offsets, sizes and strides. Special sentinels
2570 /// encode the dynamic case.
2571 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2572  ArrayRef<int64_t> staticOffsets,
2573  ArrayRef<int64_t> staticSizes,
2574  ArrayRef<int64_t> staticStrides) {
2575  unsigned rank = sourceMemRefType.getRank();
2576  (void)rank;
2577  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2578  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2579  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2580 
2581  // Extract source offset and strides.
2582  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);
2583 
2584  // Compute target offset whose value is:
2585  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2586  int64_t targetOffset = sourceOffset;
2587  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2588  auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
2589  using saturated_arith::Wrapper;
2590  targetOffset =
2591  (Wrapper::offset(targetOffset) +
2592  Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
2593  .asOffset();
2594  }
2595 
2596  // Compute target stride whose value is:
2597  // `sourceStrides_i * staticStrides_i`.
2598  SmallVector<int64_t, 4> targetStrides;
2599  targetStrides.reserve(staticOffsets.size());
2600  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2601  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2602  using saturated_arith::Wrapper;
2603  targetStrides.push_back(
2604  (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
2605  .asStride());
2606  }
2607 
2608  // The type is now known.
2609  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2610  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2611  targetOffset, targetStrides),
2612  sourceMemRefType.getMemorySpace());
2613 }
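// Illustrative worked example (not from the upstream sources) of the
// computation above: for a source memref<8x16xf32> (identity layout, i.e.
// strides [16, 1], offset 0) with staticOffsets = [2, 4],
// staticSizes = [4, 4] and staticStrides = [1, 2]:
//   targetOffset  = 0 + 2 * 16 + 4 * 1 = 36
//   targetStrides = [16 * 1, 1 * 2]    = [16, 2]
// giving memref<4x4xf32, strided<[16, 2], offset: 36>>.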
2614 
2615 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2616  ArrayRef<OpFoldResult> offsets,
2617  ArrayRef<OpFoldResult> sizes,
2618  ArrayRef<OpFoldResult> strides) {
2619  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2620  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2621  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2622  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2623  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2624  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2625  staticSizes, staticStrides);
2626 }
2627 
2628 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2629  MemRefType sourceRankedTensorType,
2630  ArrayRef<int64_t> offsets,
2631  ArrayRef<int64_t> sizes,
2632  ArrayRef<int64_t> strides) {
2633  auto inferredType = llvm::cast<MemRefType>(
2634  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2635  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2636  "expected ");
2637  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2638  return inferredType;
2639 
2640  // Compute which dimensions are dropped.
2641  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2642  computeRankReductionMask(inferredType.getShape(), resultShape);
2643  assert(dimsToProject.has_value() && "invalid rank reduction");
2644 
2645  // Compute the layout and result type.
2646  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2647  SmallVector<int64_t> rankReducedStrides;
2648  rankReducedStrides.reserve(resultShape.size());
2649  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2650  if (!dimsToProject->contains(idx))
2651  rankReducedStrides.push_back(value);
2652  }
2653  return MemRefType::get(resultShape, inferredType.getElementType(),
2654  StridedLayoutAttr::get(inferredLayout.getContext(),
2655  inferredLayout.getOffset(),
2656  rankReducedStrides),
2657  inferredType.getMemorySpace());
2658 }
2659 
2660 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2661  MemRefType sourceRankedTensorType,
2662  ArrayRef<OpFoldResult> offsets,
2663  ArrayRef<OpFoldResult> sizes,
2664  ArrayRef<OpFoldResult> strides) {
2665  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2666  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2667  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2668  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2669  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2670  return SubViewOp::inferRankReducedResultType(
2671  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2672  staticStrides);
2673 }
2674 
2675 // Build a SubViewOp with mixed static and dynamic entries and custom result
2676 // type. If the type passed is nullptr, it is inferred.
2677 void SubViewOp::build(OpBuilder &b, OperationState &result,
2678  MemRefType resultType, Value source,
2679  ArrayRef<OpFoldResult> offsets,
2680  ArrayRef<OpFoldResult> sizes,
2681  ArrayRef<OpFoldResult> strides,
2682  ArrayRef<NamedAttribute> attrs) {
2683  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2684  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2685  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2686  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2687  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2688  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2689  // Structuring implementation this way avoids duplication between builders.
2690  if (!resultType) {
2691  resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2692  sourceMemRefType, staticOffsets, staticSizes, staticStrides));
2693  }
2694  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2695  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2696  b.getDenseI64ArrayAttr(staticSizes),
2697  b.getDenseI64ArrayAttr(staticStrides));
2698  result.addAttributes(attrs);
2699 }
2700 
2701 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2702 // type.
2703 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2704  ArrayRef<OpFoldResult> offsets,
2705  ArrayRef<OpFoldResult> sizes,
2706  ArrayRef<OpFoldResult> strides,
2707  ArrayRef<NamedAttribute> attrs) {
2708  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2709 }
2710 
2711 // Build a SubViewOp with static entries and inferred result type.
2712 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2713  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2714  ArrayRef<int64_t> strides,
2715  ArrayRef<NamedAttribute> attrs) {
2716  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2717  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2718  return b.getI64IntegerAttr(v);
2719  }));
2720  SmallVector<OpFoldResult> sizeValues =
2721  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2722  return b.getI64IntegerAttr(v);
2723  }));
2724  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2725  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2726  return b.getI64IntegerAttr(v);
2727  }));
2728  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2729 }
2730 
2731 // Build a SubViewOp with dynamic entries and custom result type. If the
2732 // type passed is nullptr, it is inferred.
2733 void SubViewOp::build(OpBuilder &b, OperationState &result,
2734  MemRefType resultType, Value source,
2735  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2736  ArrayRef<int64_t> strides,
2737  ArrayRef<NamedAttribute> attrs) {
2738  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2739  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2740  return b.getI64IntegerAttr(v);
2741  }));
2742  SmallVector<OpFoldResult> sizeValues =
2743  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2744  return b.getI64IntegerAttr(v);
2745  }));
2746  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2747  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2748  return b.getI64IntegerAttr(v);
2749  }));
2750  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2751  attrs);
2752 }
2753 
2754 // Build a SubViewOp with dynamic entries and custom result type. If the type
2755 // passed is nullptr, it is inferred.
2756 void SubViewOp::build(OpBuilder &b, OperationState &result,
2757  MemRefType resultType, Value source, ValueRange offsets,
2758  ValueRange sizes, ValueRange strides,
2759  ArrayRef<NamedAttribute> attrs) {
2760  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2761  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2762  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2763  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2764  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2765  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2766  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2767 }
2768 
2769 // Build a SubViewOp with dynamic entries and inferred result type.
2770 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2771  ValueRange offsets, ValueRange sizes, ValueRange strides,
2772  ArrayRef<NamedAttribute> attrs) {
2773  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2774 }
2775 
2776 /// For ViewLikeOpInterface.
2777 Value SubViewOp::getViewSource() { return getSource(); }
2778 
2779 /// Return true if t1 and t2 have equal offsets (both dynamic or of same
2780 /// static value).
2781 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2782  int64_t t1Offset, t2Offset;
2783  SmallVector<int64_t> t1Strides, t2Strides;
2784  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2785  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2786  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2787 }
2788 
2789 /// Checks if the `original` type can be rank reduced to the `reduced` type.
2790 /// This function is a slight variant of the `is subsequence` algorithm where
2791 /// non-matching dimensions must be 1.
2792 static SliceVerificationResult
2793 isRankReducedMemRefType(MemRefType originalType,
2794  MemRefType candidateRankReducedType,
2795  ArrayRef<OpFoldResult> sizes) {
2796  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
2797  if (partialRes != SliceVerificationResult::Success)
2798  return partialRes;
2799 
2800  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
2801  originalType, candidateRankReducedType, sizes);
2802 
2803  // Sizes cannot be matched if an empty vector is returned.
2804  if (!optionalUnusedDimsMask)
2805  return SliceVerificationResult::LayoutMismatch;
2806 
2807  if (originalType.getMemorySpace() !=
2808  candidateRankReducedType.getMemorySpace())
2809  return SliceVerificationResult::MemSpaceMismatch;
2810 
2811  // No amount of stride dropping can reconcile incompatible offsets.
2812  if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
2813  return SliceVerificationResult::LayoutMismatch;
2814 
2815  return SliceVerificationResult::Success;
2816 }
2817 
2818 template <typename OpTy>
2819 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2820  OpTy op, Type expectedType) {
2821  auto memrefType = llvm::cast<ShapedType>(expectedType);
2822  switch (result) {
2823  case SliceVerificationResult::Success:
2824  return success();
2825  case SliceVerificationResult::RankTooLarge:
2826  return op.emitError("expected result rank to be smaller or equal to ")
2827  << "the source rank. ";
2828  case SliceVerificationResult::SizeMismatch:
2829  return op.emitError("expected result type to be ")
2830  << expectedType
2831  << " or a rank-reduced version. (mismatch of result sizes) ";
2832  case SliceVerificationResult::ElemTypeMismatch:
2833  return op.emitError("expected result element type to be ")
2834  << memrefType.getElementType();
2835  case SliceVerificationResult::MemSpaceMismatch:
2836  return op.emitError("expected result and source memory spaces to match.");
2837  case SliceVerificationResult::LayoutMismatch:
2838  return op.emitError("expected result type to be ")
2839  << expectedType
2840  << " or a rank-reduced version. (mismatch of result layout) ";
2841  }
2842  llvm_unreachable("unexpected subview verification result");
2843 }
2844 
2845 /// Verifier for SubViewOp.
2846 LogicalResult SubViewOp::verify() {
2847  MemRefType baseType = getSourceType();
2848  MemRefType subViewType = getType();
2849 
2850  // The base memref and the view memref should be in the same memory space.
2851  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2852  return emitError("different memory spaces specified for base memref "
2853  "type ")
2854  << baseType << " and subview memref type " << subViewType;
2855 
2856  // Verify that the base memref type has a strided layout map.
2857  if (!isStrided(baseType))
2858  return emitError("base type ") << baseType << " is not strided";
2859 
2860  // Verify result type against inferred type.
2861  auto expectedType = SubViewOp::inferResultType(
2862  baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
2863 
2864  auto result = isRankReducedMemRefType(llvm::cast<MemRefType>(expectedType),
2865  subViewType, getMixedSizes());
2866  return produceSubViewErrorMsg(result, *this, expectedType);
2867 }
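// Illustrative example (not from the upstream sources; names are made up) of a
// rank-reducing subview accepted by the verifier above:
//
//   %1 = memref.subview %0[0, 2, 0] [1, 4, 4] [1, 1, 1]
//       : memref<8x16x4xf32> to memref<4x4xf32, strided<[4, 1], offset: 8>>
//
// The unit dimension of the inferred type memref<1x4x4xf32,
// strided<[64, 4, 1], offset: 8>> is dropped in the rank-reduced result.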
2868 
2869 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2870  return os << "range " << range.offset << ":" << range.size << ":"
2871  << range.stride;
2872 }
2873 
2874 /// Return the list of Range (i.e. offset, size, stride). Each Range
2875 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2876 /// with `b` at location `loc`.
2877 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2878  OpBuilder &b, Location loc) {
2879  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
2880  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
2881  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
2882  SmallVector<Range, 8> res;
2883  unsigned rank = ranks[0];
2884  res.reserve(rank);
2885  for (unsigned idx = 0; idx < rank; ++idx) {
2886  Value offset =
2887  op.isDynamicOffset(idx)
2888  ? op.getDynamicOffset(idx)
2889  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
2890  Value size =
2891  op.isDynamicSize(idx)
2892  ? op.getDynamicSize(idx)
2893  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
2894  Value stride =
2895  op.isDynamicStride(idx)
2896  ? op.getDynamicStride(idx)
2897  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
2898  res.emplace_back(Range{offset, size, stride});
2899  }
2900  return res;
2901 }
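// Usage sketch (hypothetical variable names; any op implementing
// OffsetSizeAndStrideOpInterface, e.g. a memref.subview, works):
//   SmallVector<Range, 8> ranges = getOrCreateRanges(subViewOp, builder, loc);
// Each returned entry holds SSA values; static offsets/sizes/strides are
// materialized as arith.constant index ops at `loc`.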
2902 
2903 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
2904 /// to deduce the result type for the given `sourceType`. Additionally, reduce
2905 /// the rank of the inferred result type if `currentResultType` is lower rank
2906 /// than `currentSourceType`. Use this signature if `sourceType` is updated
2907 /// together with the result type. In this case, it is important to compute
2908 /// the dropped dimensions using `currentSourceType` whose strides align with
2909 /// `currentResultType`.
2910 static MemRefType getCanonicalSubViewResultType(
2911  MemRefType currentResultType, MemRefType currentSourceType,
2912  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
2913  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
2914  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2915  sourceType, mixedOffsets, mixedSizes, mixedStrides));
2916  std::optional<llvm::SmallBitVector> unusedDims =
2917  computeMemRefRankReductionMask(currentSourceType, currentResultType,
2918  mixedSizes);
2919  // Return nullptr as failure mode.
2920  if (!unusedDims)
2921  return nullptr;
2922 
2923  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
2924  SmallVector<int64_t> shape, strides;
2925  unsigned numDimsAfterReduction =
2926  nonRankReducedType.getRank() - unusedDims->count();
2927  shape.reserve(numDimsAfterReduction);
2928  strides.reserve(numDimsAfterReduction);
2929  for (const auto &[idx, size, stride] :
2930  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
2931  nonRankReducedType.getShape(), layout.getStrides())) {
2932  if (unusedDims->test(idx))
2933  continue;
2934  shape.push_back(size);
2935  strides.push_back(stride);
2936  }
2937 
2938  return MemRefType::get(shape, nonRankReducedType.getElementType(),
2939  StridedLayoutAttr::get(sourceType.getContext(),
2940  layout.getOffset(), strides),
2941  nonRankReducedType.getMemorySpace());
2942 }
2943 
2944 Value mlir::memref::createCanonicalRankReducingSubViewOp(
2945  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
2946  auto memrefType = llvm::cast<MemRefType>(memref.getType());
2947  unsigned rank = memrefType.getRank();
2948  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2949  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
2950  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2951  auto targetType =
2952  llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
2953  targetShape, memrefType, offsets, sizes, strides));
2954  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
2955  sizes, strides);
2956 }
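// Illustrative result (a sketch with assumed types): for a source of type
// memref<1x16xf32> and targetShape {16}, the helper builds
//   %sv = memref.subview %m[0, 0] [1, 16] [1, 1]
//       : memref<1x16xf32> to memref<16xf32, strided<[1]>>
// i.e. zero offsets, unit strides, the full source sizes, and a rank-reduced
// result type.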
2957 
2958 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
2959  Value value,
2960  ArrayRef<int64_t> desiredShape) {
2961  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
2962  assert(sourceMemrefType && "not a ranked memref type");
2963  auto sourceShape = sourceMemrefType.getShape();
2964  if (sourceShape.equals(desiredShape))
2965  return value;
2966  auto maybeRankReductionMask =
2967  mlir::computeRankReductionMask(sourceShape, desiredShape);
2968  if (!maybeRankReductionMask)
2969  return failure();
2970  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
2971 }
2972 
2973 /// Helper method to check if a `subview` operation is trivially a no-op. This
2974 /// is the case if all offsets are zero, all strides are 1, and the subview
2975 /// sizes equal the source shape. In such cases, the subview can
2976 /// be folded into its source.
2977 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
2978  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
2979  return false;
2980 
2981  auto mixedOffsets = subViewOp.getMixedOffsets();
2982  auto mixedSizes = subViewOp.getMixedSizes();
2983  auto mixedStrides = subViewOp.getMixedStrides();
2984 
2985  // Check offsets are zero.
2986  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
2987  std::optional<int64_t> intValue = getConstantIntValue(ofr);
2988  return !intValue || intValue.value() != 0;
2989  }))
2990  return false;
2991 
2992  // Check strides are one.
2993  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
2994  std::optional<int64_t> intValue = getConstantIntValue(ofr);
2995  return !intValue || intValue.value() != 1;
2996  }))
2997  return false;
2998 
2999  // Check all size values are static and match the (static) source shape.
3000  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3001  for (const auto &size : llvm::enumerate(mixedSizes)) {
3002  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3003  if (!intValue || *intValue != sourceShape[size.index()])
3004  return false;
3005  }
3006  // All conditions met. The `SubViewOp` is foldable as a no-op.
3007  return true;
3008 }
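// For instance (a sketch; the type is an assumption), the following subview is
// trivially a no-op because all offsets are 0, all strides are 1, and the
// sizes equal the source shape:
//   %0 = memref.subview %m[0, 0] [4, 8] [1, 1]
//       : memref<4x8xf32> to memref<4x8xf32, strided<[8, 1]>>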
3009 
3010 namespace {
3011 /// Pattern to rewrite a subview op with MemRefCast arguments.
3012 /// This essentially pushes memref.cast past its consuming subview when
3013 /// `canFoldIntoConsumerOp` is true.
3014 ///
3015 /// Example:
3016 /// ```
3017 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3018 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3019 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3020 /// ```
3021 /// is rewritten into:
3022 /// ```
3023 /// %0 = memref.subview %V[0, 0][3, 4][1, 1] : memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3024 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3025 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3026 /// ```
3027 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3028 public:
3029  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3030 
3031  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3032  PatternRewriter &rewriter) const override {
3033  // Any constant operand, just return to let SubViewOpConstantFolder kick
3034  // in.
3035  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3036  return matchPattern(operand, matchConstantIndex());
3037  }))
3038  return failure();
3039 
3040  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3041  if (!castOp)
3042  return failure();
3043 
3044  if (!CastOp::canFoldIntoConsumerOp(castOp))
3045  return failure();
3046 
3047  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3048  // the MemRefCastOp source operand type to infer the result type and the
3049  // current SubViewOp source operand type to compute the dropped dimensions
3050  // if the operation is rank-reducing.
3051  auto resultType = getCanonicalSubViewResultType(
3052  subViewOp.getType(), subViewOp.getSourceType(),
3053  llvm::cast<MemRefType>(castOp.getSource().getType()),
3054  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3055  subViewOp.getMixedStrides());
3056  if (!resultType)
3057  return failure();
3058 
3059  Value newSubView = rewriter.create<SubViewOp>(
3060  subViewOp.getLoc(), resultType, castOp.getSource(),
3061  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3062  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3063  subViewOp.getStaticStrides());
3064  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3065  newSubView);
3066  return success();
3067  }
3068 };
3069 
3070 /// Canonicalize subview ops that are no-ops; if the source and result types
3071 /// differ only in their layout, replace the subview with a `memref.cast`.
3072 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3073 public:
3074  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3075 
3076  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3077  PatternRewriter &rewriter) const override {
3078  if (!isTrivialSubViewOp(subViewOp))
3079  return failure();
3080  if (subViewOp.getSourceType() == subViewOp.getType()) {
3081  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3082  return success();
3083  }
3084  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3085  subViewOp.getSource());
3086  return success();
3087  }
3088 };
3089 } // namespace
3090 
3091 /// Return the canonical type of the result of a subview.
3092 struct SubViewReturnTypeCanonicalizer {
3093  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3094  ArrayRef<OpFoldResult> mixedSizes,
3095  ArrayRef<OpFoldResult> mixedStrides) {
3096  // Infer a memref type without taking into account any rank reductions.
3097  MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
3098  op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));
3099 
3100  // Directly return the non-rank reduced type if there are no dropped dims.
3101  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3102  if (droppedDims.empty())
3103  return nonReducedType;
3104 
3105  // Take the strides and offset from the non-rank reduced type.
3106  auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
3107 
3108  // Drop dims from shape and strides.
3109  SmallVector<int64_t> targetShape;
3110  SmallVector<int64_t> targetStrides;
3111  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3112  if (droppedDims.test(i))
3113  continue;
3114  targetStrides.push_back(nonReducedStrides[i]);
3115  targetShape.push_back(nonReducedType.getDimSize(i));
3116  }
3117 
3118  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3119  StridedLayoutAttr::get(nonReducedType.getContext(),
3120  offset, targetStrides),
3121  nonReducedType.getMemorySpace());
3122  }
3123 };
3124 
3125 /// A canonicalizer wrapper to replace SubViewOps.
3126 struct SubViewCanonicalizer {
3127  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3128  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3129  }
3130 };
3131 
3132 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3133  MLIRContext *context) {
3134  results
3135  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3136  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3137  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3138 }
3139 
3140 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3141  auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
3142  auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());
3143 
3144  if (resultShapedType.hasStaticShape() &&
3145  resultShapedType == sourceShapedType) {
3146  return getViewSource();
3147  }
3148 
3149  // Fold subview(subview(x)), where both subviews have the same size and the
3150  // second subview's offsets are all zero. (I.e., the second subview is a
3151  // no-op.)
3152  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3153  auto srcSizes = srcSubview.getMixedSizes();
3154  auto sizes = getMixedSizes();
3155  auto offsets = getMixedOffsets();
3156  bool allOffsetsZero = llvm::all_of(
3157  offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3158  auto strides = getMixedStrides();
3159  bool allStridesOne = llvm::all_of(
3160  strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3161  bool allSizesSame = llvm::equal(sizes, srcSizes);
3162  if (allOffsetsZero && allStridesOne && allSizesSame &&
3163  resultShapedType == sourceShapedType)
3164  return getViewSource();
3165  }
3166 
3167  return {};
3168 }
3169 
3170 //===----------------------------------------------------------------------===//
3171 // TransposeOp
3172 //===----------------------------------------------------------------------===//
3173 
3174 void TransposeOp::getAsmResultNames(
3175  function_ref<void(Value, StringRef)> setNameFn) {
3176  setNameFn(getResult(), "transpose");
3177 }
3178 
3179 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3180 static MemRefType inferTransposeResultType(MemRefType memRefType,
3181  AffineMap permutationMap) {
3182  auto rank = memRefType.getRank();
3183  auto originalSizes = memRefType.getShape();
3184  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
3185  assert(originalStrides.size() == static_cast<unsigned>(rank));
3186 
3187  // Compute permuted sizes and strides.
3188  SmallVector<int64_t> sizes(rank, 0);
3189  SmallVector<int64_t> strides(rank, 1);
3190  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
3191  unsigned position = en.value().cast<AffineDimExpr>().getPosition();
3192  sizes[en.index()] = originalSizes[position];
3193  strides[en.index()] = originalStrides[position];
3194  }
3195 
3196  return MemRefType::Builder(memRefType)
3197  .setShape(sizes)
3198  .setLayout(
3199  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3200 }
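// Worked example (a sketch): for memRefType = memref<2x3xf32> (strides [3, 1],
// offset 0) and permutationMap = (d0, d1) -> (d1, d0), the permuted sizes are
// [3, 2] and the permuted strides are [1, 3], so the inferred result type is
// memref<3x2xf32, strided<[1, 3]>>.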
3201 
3202 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3203  AffineMapAttr permutation,
3204  ArrayRef<NamedAttribute> attrs) {
3205  auto permutationMap = permutation.getValue();
3206  assert(permutationMap);
3207 
3208  auto memRefType = llvm::cast<MemRefType>(in.getType());
3209  // Compute result type.
3210  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3211 
3212  build(b, result, resultType, in, attrs);
3213  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3214 }
3215 
3216 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3217 void TransposeOp::print(OpAsmPrinter &p) {
3218  p << " " << getIn() << " " << getPermutation();
3219  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3220  p << " : " << getIn().getType() << " to " << getType();
3221 }
3222 
3223 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3224  OpAsmParser::UnresolvedOperand in;
3225  AffineMap permutation;
3226  MemRefType srcType, dstType;
3227  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3228  parser.parseOptionalAttrDict(result.attributes) ||
3229  parser.parseColonType(srcType) ||
3230  parser.resolveOperand(in, srcType, result.operands) ||
3231  parser.parseKeywordType("to", dstType) ||
3232  parser.addTypeToList(dstType, result.types))
3233  return failure();
3234 
3235  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3236  AffineMapAttr::get(permutation));
3237  return success();
3238 }
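// The textual form handled by the print/parse methods above looks like the
// following (a sketch; the types are illustrative assumptions):
//   %t = memref.transpose %m (d0, d1) -> (d1, d0)
//       : memref<2x3xf32> to memref<3x2xf32, strided<[1, 3]>>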
3239 
3240 LogicalResult TransposeOp::verify() {
3241  if (!getPermutation().isPermutation())
3242  return emitOpError("expected a permutation map");
3243  if (getPermutation().getNumDims() != getIn().getType().getRank())
3244  return emitOpError("expected a permutation map of same rank as the input");
3245 
3246  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3247  auto dstType = llvm::cast<MemRefType>(getType());
3248  auto transposedType = inferTransposeResultType(srcType, getPermutation());
3249  if (dstType != transposedType)
3250  return emitOpError("output type ")
3251  << dstType << " does not match transposed input type " << srcType
3252  << ", " << transposedType;
3253  return success();
3254 }
3255 
3256 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3257  if (succeeded(foldMemRefCast(*this)))
3258  return getResult();
3259  return {};
3260 }
3261 
3262 //===----------------------------------------------------------------------===//
3263 // ViewOp
3264 //===----------------------------------------------------------------------===//
3265 
3266 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3267  setNameFn(getResult(), "view");
3268 }
3269 
3270 LogicalResult ViewOp::verify() {
3271  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3272  auto viewType = getType();
3273 
3274  // The base memref should have identity layout map (or none).
3275  if (!baseType.getLayout().isIdentity())
3276  return emitError("unsupported map for base memref type ") << baseType;
3277 
3278  // The result memref should have identity layout map (or none).
3279  if (!viewType.getLayout().isIdentity())
3280  return emitError("unsupported map for result memref type ") << viewType;
3281 
3282  // The base memref and the view memref should be in the same memory space.
3283  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3284  return emitError("different memory spaces specified for base memref "
3285  "type ")
3286  << baseType << " and view memref type " << viewType;
3287 
3288  // Verify that we have the correct number of sizes for the result type.
3289  unsigned numDynamicDims = viewType.getNumDynamicDims();
3290  if (getSizes().size() != numDynamicDims)
3291  return emitError("incorrect number of size operands for type ") << viewType;
3292 
3293  return success();
3294 }
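// Example of a view this verifier accepts (a sketch; shapes are assumptions):
//   %v = memref.view %buf[%byte_shift][%size]
//       : memref<2048xi8> to memref<?x4xf32>
// where exactly one size operand is supplied for the single dynamic dimension.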
3295 
3296 Value ViewOp::getViewSource() { return getSource(); }
3297 
3298 namespace {
3299 
3300 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3301  using OpRewritePattern<ViewOp>::OpRewritePattern;
3302 
3303  LogicalResult matchAndRewrite(ViewOp viewOp,
3304  PatternRewriter &rewriter) const override {
3305  // Return if none of the operands are constants.
3306  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3307  return matchPattern(operand, matchConstantIndex());
3308  }))
3309  return failure();
3310 
3311  // Get result memref type.
3312  auto memrefType = viewOp.getType();
3313 
3314  // Get offset from old memref view type 'memRefType'.
3315  int64_t oldOffset;
3316  SmallVector<int64_t, 4> oldStrides;
3317  if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
3318  return failure();
3319  assert(oldOffset == 0 && "Expected 0 offset");
3320 
3321  SmallVector<Value, 4> newOperands;
3322 
3323  // Offset cannot be folded into result type.
3324 
3325  // Fold any dynamic dim operands which are produced by a constant.
3326  SmallVector<int64_t, 4> newShapeConstants;
3327  newShapeConstants.reserve(memrefType.getRank());
3328 
3329  unsigned dynamicDimPos = 0;
3330  unsigned rank = memrefType.getRank();
3331  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3332  int64_t dimSize = memrefType.getDimSize(dim);
3333  // If this is already a static dimension, keep it.
3334  if (!ShapedType::isDynamic(dimSize)) {
3335  newShapeConstants.push_back(dimSize);
3336  continue;
3337  }
3338  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3339  if (auto constantIndexOp =
3340  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3341  // Dynamic shape dimension will be folded.
3342  newShapeConstants.push_back(constantIndexOp.value());
3343  } else {
3344  // Dynamic shape dimension not folded; copy operand from old memref.
3345  newShapeConstants.push_back(dimSize);
3346  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3347  }
3348  dynamicDimPos++;
3349  }
3350 
3351  // Create new memref type with constant folded dims.
3352  MemRefType newMemRefType =
3353  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3354  // Nothing new, don't fold.
3355  if (newMemRefType == memrefType)
3356  return failure();
3357 
3358  // Create new ViewOp.
3359  auto newViewOp = rewriter.create<ViewOp>(
3360  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3361  viewOp.getByteShift(), newOperands);
3362  // Insert a cast so we have the same type as the old memref type.
3363  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3364  return success();
3365  }
3366 };
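// Illustrative rewrite performed by ViewOpShapeFolder (a sketch; names and
// shapes are assumptions, not taken from a test):
//   %c64 = arith.constant 64 : index
//   %v = memref.view %buf[%off][%c64] : memref<2048xi8> to memref<?x4xf32>
// becomes
//   %s = memref.view %buf[%off][] : memref<2048xi8> to memref<64x4xf32>
//   %v = memref.cast %s : memref<64x4xf32> to memref<?x4xf32>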
3367 
3368 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3369  using OpRewritePattern<ViewOp>::OpRewritePattern;
3370 
3371  LogicalResult matchAndRewrite(ViewOp viewOp,
3372  PatternRewriter &rewriter) const override {
3373  Value memrefOperand = viewOp.getOperand(0);
3374  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3375  if (!memrefCastOp)
3376  return failure();
3377  Value allocOperand = memrefCastOp.getOperand();
3378  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3379  if (!allocOp)
3380  return failure();
3381  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3382  viewOp.getByteShift(),
3383  viewOp.getSizes());
3384  return success();
3385  }
3386 };
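// Illustrative rewrite performed by ViewOpMemrefCastFolder (a sketch): a view
// whose operand is a memref.cast of an allocation is rewritten to view the
// allocation directly, e.g.
//   %a = memref.alloc() : memref<2048xi8>
//   %c = memref.cast %a : memref<2048xi8> to memref<?xi8>
//   %v = memref.view %c[%off][%n] : memref<?xi8> to memref<?x4xf32>
// becomes a memref.view of %a with the same result type.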
3387 
3388 } // namespace
3389 
3390 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3391  MLIRContext *context) {
3392  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3393 }
3394 
3395 //===----------------------------------------------------------------------===//
3396 // AtomicRMWOp
3397 //===----------------------------------------------------------------------===//
3398 
3399 LogicalResult AtomicRMWOp::verify() {
3400  if (getMemRefType().getRank() != getNumOperands() - 2)
3401  return emitOpError(
3402  "expects the number of subscripts to be equal to memref rank");
3403  switch (getKind()) {
3404  case arith::AtomicRMWKind::addf:
3405  case arith::AtomicRMWKind::maximumf:
3406  case arith::AtomicRMWKind::minimumf:
3407  case arith::AtomicRMWKind::mulf:
3408  if (!llvm::isa<FloatType>(getValue().getType()))
3409  return emitOpError() << "with kind '"
3410  << arith::stringifyAtomicRMWKind(getKind())
3411  << "' expects a floating-point type";
3412  break;
3413  case arith::AtomicRMWKind::addi:
3414  case arith::AtomicRMWKind::maxs:
3415  case arith::AtomicRMWKind::maxu:
3416  case arith::AtomicRMWKind::mins:
3417  case arith::AtomicRMWKind::minu:
3418  case arith::AtomicRMWKind::muli:
3419  case arith::AtomicRMWKind::ori:
3420  case arith::AtomicRMWKind::andi:
3421  if (!llvm::isa<IntegerType>(getValue().getType()))
3422  return emitOpError() << "with kind '"
3423  << arith::stringifyAtomicRMWKind(getKind())
3424  << "' expects an integer type";
3425  break;
3426  default:
3427  break;
3428  }
3429  return success();
3430 }
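// Example of an atomic_rmw this verifier accepts (a sketch; types are
// assumptions): an `addf` kind with a floating-point value and one subscript
// per memref dimension:
//   %old = memref.atomic_rmw addf %val, %m[%i, %j]
//       : (f32, memref<4x4xf32>) -> f32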
3431 
3432 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3433  /// atomicrmw(memrefcast) -> atomicrmw
3434  if (succeeded(foldMemRefCast(*this, getValue())))
3435  return getResult();
3436  return OpFoldResult();
3437 }
3438 
3439 //===----------------------------------------------------------------------===//
3440 // TableGen'd op method definitions
3441 //===----------------------------------------------------------------------===//
3442 
3443 #define GET_OP_CLASSES
3444 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"