MemRefOps.cpp (MLIR 16.0.0git)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Arith/Utils/Utils.h"
13 #include "mlir/Dialect/Utils/StaticValueUtils.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/ViewLikeInterface.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 namespace {
30 /// Idiomatic saturated operations on offsets, sizes and strides.
31 namespace saturated_arith {
32 struct Wrapper {
33  static Wrapper stride(int64_t v) {
34  return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
35  : Wrapper{false, v};
36  }
37  static Wrapper offset(int64_t v) {
38  return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
39  : Wrapper{false, v};
40  }
41  static Wrapper size(int64_t v) {
42  return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
43  }
44  int64_t asOffset() {
45  return saturated ? ShapedType::kDynamic : v;
46  }
47  int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
48  int64_t asStride() {
49  return saturated ? ShapedType::kDynamic : v;
50  }
51  bool operator==(Wrapper other) {
52  return (saturated && other.saturated) ||
53  (!saturated && !other.saturated && v == other.v);
54  }
55  bool operator!=(Wrapper other) { return !(*this == other); }
56  Wrapper operator+(Wrapper other) {
57  if (saturated || other.saturated)
58  return Wrapper{true, 0};
59  return Wrapper{false, other.v + v};
60  }
61  Wrapper operator*(Wrapper other) {
62  if (saturated || other.saturated)
63  return Wrapper{true, 0};
64  return Wrapper{false, other.v * v};
65  }
66  bool saturated;
67  int64_t v;
68 };
69 } // namespace saturated_arith
70 } // namespace
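// For illustration only, a hypothetical use of the wrapper above (not taken
// from the surrounding code): combining a dynamic value with anything
// saturates to ShapedType::kDynamic, while fully static values follow plain
// arithmetic.
//
//   using namespace saturated_arith;
//   int64_t dynamicStride = ShapedType::kDynamic;
//   int64_t combined =
//       (Wrapper::stride(dynamicStride) * Wrapper::stride(4)).asStride();
//   // combined == ShapedType::kDynamic, because one operand was dynamic.
//   int64_t staticCombined =
//       (Wrapper::stride(2) * Wrapper::stride(4)).asStride(); // == 8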
71 
72 /// Materialize a single constant operation from a given attribute value with
73 /// the desired resultant type.
74 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
75  Attribute value, Type type,
76  Location loc) {
77  if (arith::ConstantOp::isBuildableWith(value, type))
78  return builder.create<arith::ConstantOp>(loc, value, type);
79  return nullptr;
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // Common canonicalization pattern support logic
84 //===----------------------------------------------------------------------===//
85 
86 /// This is a common class used for patterns of the form
87 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
88 /// into the root operation directly.
89 static LogicalResult foldMemRefCast(Operation *op, Value inner = Value()) {
90  bool folded = false;
91  for (OpOperand &operand : op->getOpOperands()) {
92  auto cast = operand.get().getDefiningOp<CastOp>();
93  if (cast && operand.get() != inner &&
94  !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
95  operand.set(cast.getOperand());
96  folded = true;
97  }
98  }
99  return success(folded);
100 }
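// For illustration, a hypothetical fold enabled by the helper above: an op
// consuming the result of a ranked memref.cast is rewritten to consume the
// cast's source directly.
// ```mlir
//   %0 = memref.cast %arg : memref<8x16xf32> to memref<?x?xf32>
//   memref.dealloc %0 : memref<?x?xf32>
// ```
// may become:
// ```mlir
//   memref.dealloc %arg : memref<8x16xf32>
// ```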
101 
102 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
103 /// type.
104 static Type getTensorTypeFromMemRefType(Type type) {
105  if (auto memref = type.dyn_cast<MemRefType>())
106  return RankedTensorType::get(memref.getShape(), memref.getElementType());
107  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
108  return UnrankedTensorType::get(memref.getElementType());
109  return NoneType::get(type.getContext());
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Utility functions for propagating static information
114 //===----------------------------------------------------------------------===//
115 
116 /// Helper function that infers the constant values from a list of \p values,
117 /// a \p memRefTy, and another helper function \p getAttributes.
118 /// The inferred constant values replace the related `OpFoldResult` in
119 /// \p values.
120 ///
121 /// \note This function shouldn't be used directly, instead, use the
122 /// `getConstifiedMixedXXX` methods from the related operations.
123 ///
124 /// \p getAttributes returns a list of potentially constant values, as determined
125 /// by \p isDynamic, from the given \p memRefTy. The returned list must have as
126 /// many elements as \p values or be empty.
127 ///
128 /// E.g., consider the following example:
129 /// ```
130 /// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
131 /// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
132 /// ```
133 /// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
134 /// Now using this helper function with:
135 /// - `values == [2, %dyn_stride]`,
136 /// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
137 /// - `getAttributes == getConstantStrides` (i.e., a wrapper around
138 /// `getStridesAndOffset`), and
139 /// - `isDynamic == ShapedType::isDynamic`
140 /// Will yield: `values == [2, 1]`
141 static void constifyIndexValues(
142  SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
143  MLIRContext *ctxt,
144  llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
145  llvm::function_ref<bool(int64_t)> isDynamic) {
146  SmallVector<int64_t> constValues = getAttributes(memRefTy);
147  Builder builder(ctxt);
148  for (const auto &it : llvm::enumerate(constValues)) {
149  int64_t constValue = it.value();
150  if (!isDynamic(constValue))
151  values[it.index()] = builder.getIndexAttr(constValue);
152  }
153  for (OpFoldResult &ofr : values) {
154  if (ofr.is<Attribute>()) {
155  // FIXME: We shouldn't need to do that, but right now, the static indices
156  // are created with the wrong type: `i64` instead of `index`.
157  // As a result, if we were to keep the attribute as is, we may fail to see
158  // that two attributes are equal because one would have the i64 type and
159  // the other the index type.
160  // The alternative would be to create constant indices with getI64Attr in
161  // this and the previous loop, but it doesn't logically make sense (we are
162  // dealing with indices here) and would only strengthen the inconsistency
163  // around how static indices are created (some places use getI64Attr,
164  // others use getIndexAttr).
165  // The workaround here is to stick to the IndexAttr type for all the
166  // values, hence we recreate the attribute even when it is already static
167  // to make sure the type is consistent.
168  ofr = builder.getIndexAttr(
169  ofr.get<Attribute>().cast<IntegerAttr>().getInt());
170  continue;
171  }
172  Optional<int64_t> maybeConstant = getConstantIntValue(ofr.get<Value>());
173  if (maybeConstant)
174  ofr = builder.getIndexAttr(*maybeConstant);
175  }
176 }
177 
178 /// Wrapper around `getShape` that conforms to the function signature
179 /// expected for `getAttributes` in `constifyIndexValues`.
180 static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
181  ArrayRef<int64_t> sizes = memRefTy.getShape();
182  return SmallVector<int64_t>(sizes.begin(), sizes.end());
183 }
184 
185 /// Wrapper around `getStridesAndOffset` that returns only the offset and
186 /// conforms to the function signature expected for `getAttributes` in
187 /// `constifyIndexValues`.
188 static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
189  SmallVector<int64_t> strides;
190  int64_t offset;
191  LogicalResult hasStaticInformation =
192  getStridesAndOffset(memrefType, strides, offset);
193  if (failed(hasStaticInformation))
194  return SmallVector<int64_t>();
195  return SmallVector<int64_t>(1, offset);
196 }
197 
198 /// Wrapper around `getStridesAndOffset` that returns only the strides and
199 /// conforms to the function signature expected for `getAttributes` in
200 /// `constifyIndexValues`.
201 static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
202  SmallVector<int64_t> strides;
203  int64_t offset;
204  LogicalResult hasStaticInformation =
205  getStridesAndOffset(memrefType, strides, offset);
206  if (failed(hasStaticInformation))
207  return SmallVector<int64_t>();
208  return strides;
209 }
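// For illustration, assuming a memref type memref<4x?xf32, strided<[?, 1],
// offset: 8>>: getConstantSizes returns {4, kDynamic}, getConstantOffset
// returns {8}, and getConstantStrides returns {kDynamic, 1}. The dynamic
// entries are the ones constifyIndexValues leaves to the SSA operands.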
210 
211 //===----------------------------------------------------------------------===//
212 // AllocOp / AllocaOp
213 //===----------------------------------------------------------------------===//
214 
215 void AllocOp::getAsmResultNames(
216  function_ref<void(Value, StringRef)> setNameFn) {
217  setNameFn(getResult(), "alloc");
218 }
219 
220 void AllocaOp::getAsmResultNames(
221  function_ref<void(Value, StringRef)> setNameFn) {
222  setNameFn(getResult(), "alloca");
223 }
224 
225 template <typename AllocLikeOp>
226 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
227  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
228  "applies to only alloc or alloca");
229  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
230  if (!memRefType)
231  return op.emitOpError("result must be a memref");
232 
233  if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
234  memRefType.getNumDynamicDims())
235  return op.emitOpError("dimension operand count does not equal memref "
236  "dynamic dimension count");
237 
238  unsigned numSymbols = 0;
239  if (!memRefType.getLayout().isIdentity())
240  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
241  if (op.getSymbolOperands().size() != numSymbols)
242  return op.emitOpError("symbol operand count does not equal memref symbol "
243  "count: expected ")
244  << numSymbols << ", got " << op.getSymbolOperands().size();
245 
246  return success();
247 }
248 
249 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
250 
251 LogicalResult AllocaOp::verify() {
252  // An alloca op needs to have an ancestor with an allocation scope trait.
253  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
254  return emitOpError(
255  "requires an ancestor op with AutomaticAllocationScope trait");
256 
257  return verifyAllocLikeOp(*this);
258 }
259 
260 namespace {
261 /// Fold constant dimensions into an alloc like operation.
262 template <typename AllocLikeOp>
263 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
264  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
265 
266  LogicalResult matchAndRewrite(AllocLikeOp alloc,
267  PatternRewriter &rewriter) const override {
268  // Check to see if any dimensions operands are constants. If so, we can
269  // substitute and drop them.
270  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
271  return matchPattern(operand, matchConstantIndex());
272  }))
273  return failure();
274 
275  auto memrefType = alloc.getType();
276 
277  // Ok, we have one or more constant operands. Collect the non-constant ones
278  // and keep track of the resultant memref type to build.
279  SmallVector<int64_t, 4> newShapeConstants;
280  newShapeConstants.reserve(memrefType.getRank());
281  SmallVector<Value, 4> dynamicSizes;
282 
283  unsigned dynamicDimPos = 0;
284  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
285  int64_t dimSize = memrefType.getDimSize(dim);
286  // If this is already static dimension, keep it.
287  if (!ShapedType::isDynamic(dimSize)) {
288  newShapeConstants.push_back(dimSize);
289  continue;
290  }
291  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
292  auto *defOp = dynamicSize.getDefiningOp();
293  if (auto constantIndexOp =
294  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
295  // Dynamic shape dimension will be folded.
296  newShapeConstants.push_back(constantIndexOp.value());
297  } else {
298  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
299  newShapeConstants.push_back(ShapedType::kDynamic);
300  dynamicSizes.push_back(dynamicSize);
301  }
302  dynamicDimPos++;
303  }
304 
305  // Create new memref type (which will have fewer dynamic dimensions).
306  MemRefType newMemRefType =
307  MemRefType::Builder(memrefType).setShape(newShapeConstants);
308  assert(static_cast<int64_t>(dynamicSizes.size()) ==
309  newMemRefType.getNumDynamicDims());
310 
311  // Create and insert the alloc op for the new memref.
312  auto newAlloc = rewriter.create<AllocLikeOp>(
313  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
314  alloc.getAlignmentAttr());
315  // Insert a cast so we have the same type as the old alloc.
316  auto resultCast =
317  rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
318 
319  rewriter.replaceOp(alloc, {resultCast});
320  return success();
321  }
322 };
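// For illustration, a hypothetical input/output of SimplifyAllocConst: an
// alloc whose dynamic size is a known constant
// ```mlir
//   %c16 = arith.constant 16 : index
//   %0 = memref.alloc(%c16) : memref<?xf32>
// ```
// is rewritten into a static alloc plus a cast back to the original type:
// ```mlir
//   %1 = memref.alloc() : memref<16xf32>
//   %0 = memref.cast %1 : memref<16xf32> to memref<?xf32>
// ```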
323 
324 /// Fold alloc operations with no users or only store and dealloc uses.
325 template <typename T>
326 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
327  using OpRewritePattern<T>::OpRewritePattern;
328 
329  LogicalResult matchAndRewrite(T alloc,
330  PatternRewriter &rewriter) const override {
331  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
332  if (auto storeOp = dyn_cast<StoreOp>(op))
333  return storeOp.getValue() == alloc;
334  return !isa<DeallocOp>(op);
335  }))
336  return failure();
337 
338  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
339  rewriter.eraseOp(user);
340 
341  rewriter.eraseOp(alloc);
342  return success();
343  }
344 };
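// For illustration, a hypothetical case handled by SimplifyDeadAlloc: the
// buffer below is only written to and deallocated, so the pattern erases the
// alloc together with its store and dealloc users.
// ```mlir
//   %0 = memref.alloc() : memref<4xf32>
//   memref.store %cst, %0[%c0] : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>
// ```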
345 } // namespace
346 
347 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
348  MLIRContext *context) {
349  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
350 }
351 
352 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
353  MLIRContext *context) {
354  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
355  context);
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // ReallocOp
360 //===----------------------------------------------------------------------===//
361 
362 LogicalResult ReallocOp::verify() {
363  auto sourceType = getOperand(0).getType().cast<MemRefType>();
364  MemRefType resultType = getType();
365 
366  // The source memref should have identity layout (or none).
367  if (!sourceType.getLayout().isIdentity())
368  return emitError("unsupported layout for source memref type ")
369  << sourceType;
370 
371  // The result memref should have identity layout (or none).
372  if (!resultType.getLayout().isIdentity())
373  return emitError("unsupported layout for result memref type ")
374  << resultType;
375 
376  // The source memref and the result memref should be in the same memory space.
377  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
378  return emitError("different memory spaces specified for source memref "
379  "type ")
380  << sourceType << " and result memref type " << resultType;
381 
382  // The source memref and the result memref should have the same element type.
383  if (sourceType.getElementType() != resultType.getElementType())
384  return emitError("different element types specified for source memref "
385  "type ")
386  << sourceType << " and result memref type " << resultType;
387 
388  // Verify that we have the dynamic dimension operand when it is needed.
389  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
390  return emitError("missing dimension operand for result type ")
391  << resultType;
392  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
393  return emitError("unnecessary dimension operand for result type ")
394  << resultType;
395 
396  return success();
397 }
398 
399 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
400  MLIRContext *context) {
401  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
402 }
403 
404 //===----------------------------------------------------------------------===//
405 // AllocaScopeOp
406 //===----------------------------------------------------------------------===//
407 
408 void AllocaScopeOp::print(OpAsmPrinter &p) {
409  bool printBlockTerminators = false;
410 
411  p << ' ';
412  if (!getResults().empty()) {
413  p << " -> (" << getResultTypes() << ")";
414  printBlockTerminators = true;
415  }
416  p << ' ';
417  p.printRegion(getBodyRegion(),
418  /*printEntryBlockArgs=*/false,
419  /*printBlockTerminators=*/printBlockTerminators);
420  p.printOptionalAttrDict((*this)->getAttrs());
421 }
422 
423 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
424  // Create a region for the body.
425  result.regions.reserve(1);
426  Region *bodyRegion = result.addRegion();
427 
428  // Parse optional results type list.
429  if (parser.parseOptionalArrowTypeList(result.types))
430  return failure();
431 
432  // Parse the body region.
433  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
434  return failure();
435  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
436  result.location);
437 
438  // Parse the optional attribute list.
439  if (parser.parseOptionalAttrDict(result.attributes))
440  return failure();
441 
442  return success();
443 }
444 
445 void AllocaScopeOp::getSuccessorRegions(
446  Optional<unsigned> index, ArrayRef<Attribute> operands,
447  SmallVectorImpl<RegionSuccessor> &regions) {
448  if (index) {
449  regions.push_back(RegionSuccessor(getResults()));
450  return;
451  }
452 
453  regions.push_back(RegionSuccessor(&getBodyRegion()));
454 }
455 
456 /// Given an operation, return whether this op is guaranteed to
457 /// allocate an AutomaticAllocationScopeResource
458 static bool isGuaranteedAutomaticAllocation(Operation *op) {
459  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
460  if (!interface)
461  return false;
462  for (auto res : op->getResults()) {
463  if (auto effect =
464  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
465  if (isa<SideEffects::AutomaticAllocationScopeResource>(
466  effect->getResource()))
467  return true;
468  }
469  }
470  return false;
471 }
472 
473 /// Given an operation, return whether this op itself could
474 /// allocate an AutomaticAllocationScopeResource. Note that
475 /// this will not check whether an operation contained within
476 /// the op can allocate.
477 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
478  // This op itself doesn't create a stack allocation,
479  // the inner allocation should be handled separately.
480  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
481  return false;
482  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
483  if (!interface)
484  return true;
485  for (auto res : op->getResults()) {
486  if (auto effect =
487  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
488  if (isa<SideEffects::AutomaticAllocationScopeResource>(
489  effect->getResource()))
490  return true;
491  }
492  }
493  return false;
494 }
495 
496 /// Return whether this op is the last non terminating op
497 /// in a region. That is to say, it is in a one-block region
498 /// and is only followed by a terminator. This prevents
499 /// extending the lifetime of allocations.
500 static bool lastNonTerminatorInRegion(Operation *op) {
501  return op->getNextNode() == op->getBlock()->getTerminator() &&
502  op->getParentRegion()->getBlocks().size() == 1;
503 }
504 
505 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
506 /// or it contains no allocation.
507 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
508  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
509 
510  LogicalResult matchAndRewrite(AllocaScopeOp op,
511  PatternRewriter &rewriter) const override {
512  bool hasPotentialAlloca =
513  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
514  if (alloc == op)
515  return WalkResult::advance();
516  if (isOpItselfPotentialAutomaticAllocation(alloc))
517  return WalkResult::interrupt();
518  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
519  return WalkResult::skip();
520  return WalkResult::advance();
521  }).wasInterrupted();
522 
523  // If this contains no potential allocation, it is always legal to
524  // inline. Otherwise, consider two conditions:
525  if (hasPotentialAlloca) {
526  // If the parent isn't an allocation scope, or we are not the last
527  // non-terminator op in the parent, we will extend the lifetime.
528  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
529  return failure();
530  if (!lastNonTerminatorInRegion(op))
531  return failure();
532  }
533 
534  Block *block = &op.getRegion().front();
535  Operation *terminator = block->getTerminator();
536  ValueRange results = terminator->getOperands();
537  rewriter.mergeBlockBefore(block, op);
538  rewriter.replaceOp(op, results);
539  rewriter.eraseOp(terminator);
540  return success();
541  }
542 };
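// For illustration, a hypothetical scope removed by AllocaScopeInliner: the
// region below performs no stack allocation, so it can be merged into its
// parent block.
// ```mlir
//   %r = memref.alloca_scope -> (index) {
//     %v = arith.addi %a, %b : index
//     memref.alloca_scope.return %v : index
//   }
// ```
// may be rewritten to:
// ```mlir
//   %r = arith.addi %a, %b : index
// ```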
543 
544 /// Move allocations into an allocation scope, if it is legal to
545 /// move them (e.g. their operands are available at the location
546 /// the op would be moved to).
547 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
548  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
549 
550  LogicalResult matchAndRewrite(AllocaScopeOp op,
551  PatternRewriter &rewriter) const override {
552 
553  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
554  return failure();
555 
556  Operation *lastParentWithoutScope = op->getParentOp();
557 
558  if (!lastParentWithoutScope ||
559  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
560  return failure();
561 
562  // Only apply if this is the last non-terminator
563  // op in the block (lest lifetime be extended) of a
564  // one-block region.
565  if (!lastNonTerminatorInRegion(op) ||
566  !lastNonTerminatorInRegion(lastParentWithoutScope))
567  return failure();
568 
569  while (!lastParentWithoutScope->getParentOp()
570  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
571  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
572  if (!lastParentWithoutScope ||
573  !lastNonTerminatorInRegion(lastParentWithoutScope))
574  return failure();
575  }
576  assert(lastParentWithoutScope->getParentOp()
577  ->hasTrait<OpTrait::AutomaticAllocationScope>());
578 
579  Region *containingRegion = nullptr;
580  for (auto &r : lastParentWithoutScope->getRegions()) {
581  if (r.isAncestor(op->getParentRegion())) {
582  assert(containingRegion == nullptr &&
583  "only one region can contain the op");
584  containingRegion = &r;
585  }
586  }
587  assert(containingRegion && "op must be contained in a region");
588 
589  SmallVector<Operation *> toHoist;
590  op->walk([&](Operation *alloc) {
591  if (!isGuaranteedAutomaticAllocation(alloc))
592  return WalkResult::skip();
593 
594  // If any operand is not defined before the location of
595  // lastParentWithoutScope (i.e. where we would hoist to), skip.
596  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
597  return containingRegion->isAncestor(v.getParentRegion());
598  }))
599  return WalkResult::skip();
600  toHoist.push_back(alloc);
601  return WalkResult::advance();
602  });
603 
604  if (toHoist.empty())
605  return failure();
606  rewriter.setInsertionPoint(lastParentWithoutScope);
607  for (auto *op : toHoist) {
608  auto *cloned = rewriter.clone(*op);
609  rewriter.replaceOp(op, cloned->getResults());
610  }
611  return success();
612  }
613 };
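// For illustration, a hypothetical hoist performed by AllocaScopeHoister,
// assuming the scope and the loop are the last non-terminator ops in their
// blocks and the enclosing function provides the AutomaticAllocationScope:
// ```mlir
//   scf.for %i = %lb to %ub step %step {
//     memref.alloca_scope {
//       %buf = memref.alloca() : memref<16xf32>
//       ...
//     }
//   }
// ```
// The alloca has no operands defined inside the loop, so it may be cloned
// right before the scf.for, replacing its uses inside the scope.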
614 
615 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
616  MLIRContext *context) {
617  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // AssumeAlignmentOp
622 //===----------------------------------------------------------------------===//
623 
624 LogicalResult AssumeAlignmentOp::verify() {
625  if (!llvm::isPowerOf2_32(getAlignment()))
626  return emitOpError("alignment must be power of 2");
627  return success();
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // CastOp
632 //===----------------------------------------------------------------------===//
633 
634 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
635  setNameFn(getResult(), "cast");
636 }
637 
638 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
639 /// source memref. This is useful to fold a memref.cast into a consuming op
640 /// and implement canonicalization patterns for ops in different dialects that
641 /// may consume the results of memref.cast operations. Such foldable memref.cast
642 /// operations are typically inserted as `view` and `subview` ops are
643 /// canonicalized, to preserve the type compatibility of their uses.
644 ///
645 /// Returns true when all conditions are met:
646 /// 1. source and result are ranked memrefs with strided semantics and same
647 /// element type and rank.
648 /// 2. each of the source's size, offset or stride has more static information
649 /// than the corresponding result's size, offset or stride.
650 ///
651 /// Example 1:
652 /// ```mlir
653 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
654 /// %2 = consumer %1 ... : memref<?x?xf32> ...
655 /// ```
656 ///
657 /// may fold into:
658 ///
659 /// ```mlir
660 /// %2 = consumer %0 ... : memref<8x16xf32> ...
661 /// ```
662 ///
663 /// Example 2:
664 /// ```
665 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
666 /// to memref<?x?xf32>
667 /// consumer %1 : memref<?x?xf32> ...
668 /// ```
669 ///
670 /// may fold into:
671 ///
672 /// ```
673 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
674 /// ```
675 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
676  MemRefType sourceType = castOp.getSource().getType().dyn_cast<MemRefType>();
677  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
678 
679  // Requires ranked MemRefType.
680  if (!sourceType || !resultType)
681  return false;
682 
683  // Requires same elemental type.
684  if (sourceType.getElementType() != resultType.getElementType())
685  return false;
686 
687  // Requires same rank.
688  if (sourceType.getRank() != resultType.getRank())
689  return false;
690 
691  // Only fold casts between strided memref forms.
692  int64_t sourceOffset, resultOffset;
693  SmallVector<int64_t, 4> sourceStrides, resultStrides;
694  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
695  failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
696  return false;
697 
698  // If cast is towards more static sizes along any dimension, don't fold.
699  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
700  auto ss = std::get<0>(it), st = std::get<1>(it);
701  if (ss != st)
702  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
703  return false;
704  }
705 
706  // If cast is towards more static offset along any dimension, don't fold.
707  if (sourceOffset != resultOffset)
708  if (ShapedType::isDynamic(sourceOffset) &&
709  !ShapedType::isDynamic(resultOffset))
710  return false;
711 
712  // If cast is towards more static strides along any dimension, don't fold.
713  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
714  auto ss = std::get<0>(it), st = std::get<1>(it);
715  if (ss != st)
716  if (ShapedType::isDynamic(ss) &&
717  !ShapedType::isDynamic(st))
718  return false;
719  }
720 
721  return true;
722 }
723 
724 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
725  if (inputs.size() != 1 || outputs.size() != 1)
726  return false;
727  Type a = inputs.front(), b = outputs.front();
728  auto aT = a.dyn_cast<MemRefType>();
729  auto bT = b.dyn_cast<MemRefType>();
730 
731  auto uaT = a.dyn_cast<UnrankedMemRefType>();
732  auto ubT = b.dyn_cast<UnrankedMemRefType>();
733 
734  if (aT && bT) {
735  if (aT.getElementType() != bT.getElementType())
736  return false;
737  if (aT.getLayout() != bT.getLayout()) {
738  int64_t aOffset, bOffset;
739  SmallVector<int64_t, 4> aStrides, bStrides;
740  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
741  failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
742  aStrides.size() != bStrides.size())
743  return false;
744 
745  // Strides along a dimension/offset are compatible if the value in the
746  // source memref is static and the value in the target memref is the
747  // same. They are also compatible if either one is dynamic (see
748  // description of MemRefCastOp for details).
749  auto checkCompatible = [](int64_t a, int64_t b) {
750  return (ShapedType::isDynamic(a) ||
751  ShapedType::isDynamic(b) || a == b);
752  };
753  if (!checkCompatible(aOffset, bOffset))
754  return false;
755  for (const auto &aStride : enumerate(aStrides))
756  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
757  return false;
758  }
759  if (aT.getMemorySpace() != bT.getMemorySpace())
760  return false;
761 
762  // They must have the same rank, and any specified dimensions must match.
763  if (aT.getRank() != bT.getRank())
764  return false;
765 
766  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
767  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
768  if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
769  aDim != bDim)
770  return false;
771  }
772  return true;
773  } else {
774  if (!aT && !uaT)
775  return false;
776  if (!bT && !ubT)
777  return false;
778  // Unranked to unranked casting is unsupported
779  if (uaT && ubT)
780  return false;
781 
782  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
783  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
784  if (aEltType != bEltType)
785  return false;
786 
787  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
788  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
789  return aMemSpace == bMemSpace;
790  }
791 
792  return false;
793 }
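// For illustration, hypothetical casts judged by the rules above:
// ```mlir
//   memref.cast %a : memref<4xf32> to memref<?xf32>    // compatible
//   memref.cast %b : memref<?xf32> to memref<4xf32>    // compatible
//   memref.cast %c : memref<4xf32> to memref<8xf32>    // incompatible sizes
//   memref.cast %d : memref<4xf32> to memref<4xi32>    // incompatible elements
// ```
// Unranked-to-unranked casts are also rejected, as handled in the function
// above.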
794 
795 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
796  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
797 }
798 
799 //===----------------------------------------------------------------------===//
800 // CopyOp
801 //===----------------------------------------------------------------------===//
802 
803 namespace {
804 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
805 /// and element type, the cast can be skipped. Such CastOps only cast the layout
806 /// of the type.
807 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
808  using OpRewritePattern<CopyOp>::OpRewritePattern;
809 
810  LogicalResult matchAndRewrite(CopyOp copyOp,
811  PatternRewriter &rewriter) const override {
812  bool modified = false;
813 
814  // Check source.
815  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
816  auto fromType = castOp.getSource().getType().dyn_cast<MemRefType>();
817  auto toType = castOp.getType().dyn_cast<MemRefType>();
818 
819  if (fromType && toType) {
820  if (fromType.getShape() == toType.getShape() &&
821  fromType.getElementType() == toType.getElementType()) {
822  rewriter.updateRootInPlace(copyOp, [&] {
823  copyOp.getSourceMutable().assign(castOp.getSource());
824  });
825  modified = true;
826  }
827  }
828  }
829 
830  // Check target.
831  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
832  auto fromType = castOp.getSource().getType().dyn_cast<MemRefType>();
833  auto toType = castOp.getType().dyn_cast<MemRefType>();
834 
835  if (fromType && toType) {
836  if (fromType.getShape() == toType.getShape() &&
837  fromType.getElementType() == toType.getElementType()) {
838  rewriter.updateRootInPlace(copyOp, [&] {
839  copyOp.getTargetMutable().assign(castOp.getSource());
840  });
841  modified = true;
842  }
843  }
844  }
845 
846  return success(modified);
847  }
848 };
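// For illustration, a hypothetical copy simplified by FoldCopyOfCast: a cast
// that only changes the layout of the type is bypassed.
// ```mlir
//   %0 = memref.cast %src : memref<4xf32> to memref<4xf32, strided<[1]>>
//   memref.copy %0, %dst : memref<4xf32, strided<[1]>> to memref<4xf32>
// ```
// may become:
// ```mlir
//   memref.copy %src, %dst : memref<4xf32> to memref<4xf32>
// ```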
849 
850 /// Fold memref.copy(%x, %x).
851 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
852  using OpRewritePattern<CopyOp>::OpRewritePattern;
853 
854  LogicalResult matchAndRewrite(CopyOp copyOp,
855  PatternRewriter &rewriter) const override {
856  if (copyOp.getSource() != copyOp.getTarget())
857  return failure();
858 
859  rewriter.eraseOp(copyOp);
860  return success();
861  }
862 };
863 } // namespace
864 
865 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
866  MLIRContext *context) {
867  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
868 }
869 
870 LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
871  SmallVectorImpl<OpFoldResult> &results) {
872  /// copy(memrefcast) -> copy
873  bool folded = false;
874  Operation *op = *this;
875  for (OpOperand &operand : op->getOpOperands()) {
876  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
877  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
878  operand.set(castOp.getOperand());
879  folded = true;
880  }
881  }
882  return success(folded);
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // DeallocOp
887 //===----------------------------------------------------------------------===//
888 
889 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
890  SmallVectorImpl<OpFoldResult> &results) {
891  /// dealloc(memrefcast) -> dealloc
892  return foldMemRefCast(*this);
893 }
894 
895 //===----------------------------------------------------------------------===//
896 // DimOp
897 //===----------------------------------------------------------------------===//
898 
899 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
900  setNameFn(getResult(), "dim");
901 }
902 
903 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
904  int64_t index) {
905  auto loc = result.location;
906  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
907  build(builder, result, source, indexValue);
908 }
909 
910 Optional<int64_t> DimOp::getConstantIndex() {
911  return getConstantIntValue(getIndex());
912 }
913 
914 Speculation::Speculatability DimOp::getSpeculatability() {
915  auto constantIndex = getConstantIndex();
916  if (!constantIndex)
917  return Speculation::NotSpeculatable;
918 
919  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
920  if (!rankedSourceType)
921  return Speculation::NotSpeculatable;
922 
923  // The verifier rejects operations that violate this assertion.
924  assert(constantIndex < rankedSourceType.getRank());
925  return Speculation::Speculatable;
926 }
927 
928 LogicalResult DimOp::verify() {
929  // Assume unknown index to be in range.
930  Optional<int64_t> index = getConstantIndex();
931  if (!index)
932  return success();
933 
934  // Check that constant index is not knowingly out of range.
935  auto type = getSource().getType();
936  if (auto memrefType = type.dyn_cast<MemRefType>()) {
937  if (*index >= memrefType.getRank())
938  return emitOpError("index is out of range");
939  } else if (type.isa<UnrankedMemRefType>()) {
940  // Assume index to be in range.
941  } else {
942  llvm_unreachable("expected operand with memref type");
943  }
944  return success();
945 }
946 
947 /// Return a map with key being elements in `vals` and data being number of
948 /// occurrences of it. Use std::map, since the `vals` here are strides and the
949 /// dynamic stride value is the same as the tombstone value for
950 /// `DenseMap<int64_t>`.
951 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
952  std::map<int64_t, unsigned> numOccurences;
953  for (auto val : vals)
954  numOccurences[val]++;
955  return numOccurences;
956 }
957 
958 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
959 /// to be a subset of `originalType` with some `1` entries erased, return the
960 /// set of indices that specifies which of the entries of `originalShape` are
961 /// dropped to obtain `reducedShape`.
962 /// This accounts for cases where there are multiple unit-dims, but only a
963 /// subset of those are dropped. For MemRefTypes these can be disambiguated
964 /// using the strides. If a dimension is dropped the stride must be dropped too.
965 static llvm::Optional<llvm::SmallBitVector>
966 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
967  ArrayRef<OpFoldResult> sizes) {
968  llvm::SmallBitVector unusedDims(originalType.getRank());
969  if (originalType.getRank() == reducedType.getRank())
970  return unusedDims;
971 
972  for (const auto &dim : llvm::enumerate(sizes))
973  if (auto attr = dim.value().dyn_cast<Attribute>())
974  if (attr.cast<IntegerAttr>().getInt() == 1)
975  unusedDims.set(dim.index());
976 
977  // Early exit for the case where the number of unused dims matches the number
978  // of ranks reduced.
979  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
980  originalType.getRank())
981  return unusedDims;
982 
983  SmallVector<int64_t> originalStrides, candidateStrides;
984  int64_t originalOffset, candidateOffset;
985  if (failed(
986  getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
987  failed(
988  getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
989  return std::nullopt;
990 
991  // For memrefs, a dimension is truly dropped if its corresponding stride is
992  // also dropped. This is particularly important when more than one of the dims
993  // is 1. Track the number of occurrences of the strides in the original type
994  // and the candidate type. For each unused dim that stride should not be
995  // present in the candidate type. Note that there could be multiple dimensions
996  // that have the same size. We don't need to exactly figure out which dim
997  // corresponds to which stride, we just need to verify that the number of
998  // repetitions of a stride in the original + number of unused dims with that
999  // stride == number of repetitions of a stride in the candidate.
1000  std::map<int64_t, unsigned> currUnaccountedStrides =
1001  getNumOccurences(originalStrides);
1002  std::map<int64_t, unsigned> candidateStridesNumOccurences =
1003  getNumOccurences(candidateStrides);
1004  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
1005  if (!unusedDims.test(dim))
1006  continue;
1007  int64_t originalStride = originalStrides[dim];
1008  if (currUnaccountedStrides[originalStride] >
1009  candidateStridesNumOccurences[originalStride]) {
1010  // This dim can be treated as dropped.
1011  currUnaccountedStrides[originalStride]--;
1012  continue;
1013  }
1014  if (currUnaccountedStrides[originalStride] ==
1015  candidateStridesNumOccurences[originalStride]) {
1016  // The stride for this is not dropped. Keep as is.
1017  unusedDims.reset(dim);
1018  continue;
1019  }
1020  if (currUnaccountedStrides[originalStride] <
1021  candidateStridesNumOccurences[originalStride]) {
1022  // This should never happen. Can't have a stride in the reduced rank type
1023  // that wasn't in the original one.
1024  return std::nullopt;
1025  }
1026  }
1027 
1028  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1029  originalType.getRank())
1030  return std::nullopt;
1031  return unusedDims;
1032 }
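// For illustration, assuming a rank-reducing subview whose source has type
// memref<1x4x1x4xf32> (strides [16, 4, 4, 1]) and whose result has type
// memref<4x1x4xf32, strided<[4, 4, 1]>>: both unit dims are candidates, but
// stride 16 no longer appears in the result while stride 4 still appears
// twice, so the helper above reports dim 0 as dropped and keeps dim 2.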
1033 
1034 llvm::SmallBitVector SubViewOp::getDroppedDims() {
1035  MemRefType sourceType = getSourceType();
1036  MemRefType resultType = getType();
1037  llvm::Optional<llvm::SmallBitVector> unusedDims =
1038  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1039  assert(unusedDims && "unable to find unused dims of subview");
1040  return *unusedDims;
1041 }
1042 
1043 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
1044  // All forms of folding require a known index.
1045  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
1046  if (!index)
1047  return {};
1048 
1049  // Folding for unranked types (UnrankedMemRefType) is not supported.
1050  auto memrefType = getSource().getType().dyn_cast<MemRefType>();
1051  if (!memrefType)
1052  return {};
1053 
1054  // Fold if the shape extent along the given index is known.
1055  if (!memrefType.isDynamicDim(index.getInt())) {
1056  Builder builder(getContext());
1057  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1058  }
1059 
1060  // The size at the given index is now known to be a dynamic size.
1061  unsigned unsignedIndex = index.getValue().getZExtValue();
1062 
1063  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1064  Operation *definingOp = getSource().getDefiningOp();
1065 
1066  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1067  return *(alloc.getDynamicSizes().begin() +
1068  memrefType.getDynamicDimIndex(unsignedIndex));
1069 
1070  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1071  return *(alloca.getDynamicSizes().begin() +
1072  memrefType.getDynamicDimIndex(unsignedIndex));
1073 
1074  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1075  return *(view.getDynamicSizes().begin() +
1076  memrefType.getDynamicDimIndex(unsignedIndex));
1077 
1078  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1079  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1080  unsigned resultIndex = 0;
1081  unsigned sourceRank = subview.getSourceType().getRank();
1082  unsigned sourceIndex = 0;
1083  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1084  if (unusedDims.test(i))
1085  continue;
1086  if (resultIndex == unsignedIndex) {
1087  sourceIndex = i;
1088  break;
1089  }
1090  resultIndex++;
1091  }
1092  assert(subview.isDynamicSize(sourceIndex) &&
1093  "expected dynamic subview size");
1094  return subview.getDynamicSize(sourceIndex);
1095  }
1096 
1097  if (auto sizeInterface =
1098  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1099  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1100  "Expected dynamic subview size");
1101  return sizeInterface.getDynamicSize(unsignedIndex);
1102  }
1103 
1104  // dim(memrefcast) -> dim
1105  if (succeeded(foldMemRefCast(*this)))
1106  return getResult();
1107 
1108  return {};
1109 }
1110 
1111 namespace {
1112 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1113 /// operand.
1114 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1115  using OpRewritePattern<DimOp>::OpRewritePattern;
1116 
1117  LogicalResult matchAndRewrite(DimOp dim,
1118  PatternRewriter &rewriter) const override {
1119  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1120 
1121  if (!reshape)
1122  return failure();
1123 
1124  // Place the load directly after the reshape to ensure that the shape memref
1125  // was not mutated.
1126  rewriter.setInsertionPointAfter(reshape);
1127  Location loc = dim.getLoc();
1128  Value load =
1129  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1130  if (load.getType() != dim.getType())
1131  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1132  rewriter.replaceOp(dim, load);
1133  return success();
1134  }
1135 };
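// For illustration, a hypothetical rewrite performed by DimOfMemRefReshape:
// ```mlir
//   %r = memref.reshape %src(%shape)
//       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %c1 : memref<?x?xf32>
// ```
// becomes a load of the requested extent from the shape operand:
// ```mlir
//   %d = memref.load %shape[%c1] : memref<2xindex>
// ```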
1136 
1137 } // namespace
1138 
1139 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1140  MLIRContext *context) {
1141  results.add<DimOfMemRefReshape>(context);
1142 }
1143 
1144 // ---------------------------------------------------------------------------
1145 // DmaStartOp
1146 // ---------------------------------------------------------------------------
1147 
1148 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1149  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1150  ValueRange destIndices, Value numElements,
1151  Value tagMemRef, ValueRange tagIndices, Value stride,
1152  Value elementsPerStride) {
1153  result.addOperands(srcMemRef);
1154  result.addOperands(srcIndices);
1155  result.addOperands(destMemRef);
1156  result.addOperands(destIndices);
1157  result.addOperands({numElements, tagMemRef});
1158  result.addOperands(tagIndices);
1159  if (stride)
1160  result.addOperands({stride, elementsPerStride});
1161 }
1162 
1163 void DmaStartOp::print(OpAsmPrinter &p) {
1164  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1165  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1166  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1167  if (isStrided())
1168  p << ", " << getStride() << ", " << getNumElementsPerStride();
1169 
1170  p.printOptionalAttrDict((*this)->getAttrs());
1171  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1172  << ", " << getTagMemRef().getType();
1173 }
1174 
1175 // Parse DmaStartOp.
1176 // Ex:
1177 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1178 // %tag[%index], %stride, %num_elt_per_stride :
1179 // : memref<3076 x f32, 0>,
1180 // memref<1024 x f32, 2>,
1181 // memref<1 x i32>
1182 //
1183 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1184  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1185  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1186  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1187  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1188  OpAsmParser::UnresolvedOperand numElementsInfo;
1189  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1190  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1191  SmallVector<OpAsmParser::UnresolvedOperand, 4> strideInfo;
1192 
1193  SmallVector<Type, 3> types;
1194  auto indexType = parser.getBuilder().getIndexType();
1195 
1196  // Parse and resolve the following list of operands:
1197  // *) source memref followed by its indices (in square brackets).
1198  // *) destination memref followed by its indices (in square brackets).
1199  // *) dma size in KiB.
1200  if (parser.parseOperand(srcMemRefInfo) ||
1201  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1202  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1203  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1204  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1205  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1206  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1207  return failure();
1208 
1209  // Parse optional stride and elements per stride.
1210  if (parser.parseTrailingOperandList(strideInfo))
1211  return failure();
1212 
1213  bool isStrided = strideInfo.size() == 2;
1214  if (!strideInfo.empty() && !isStrided) {
1215  return parser.emitError(parser.getNameLoc(),
1216  "expected two stride related operands");
1217  }
1218 
1219  if (parser.parseColonTypeList(types))
1220  return failure();
1221  if (types.size() != 3)
1222  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1223 
1224  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1225  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1226  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1227  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1228  // size should be an index.
1229  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1230  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1231  // tag indices should be index.
1232  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1233  return failure();
1234 
1235  if (isStrided) {
1236  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1237  return failure();
1238  }
1239 
1240  return success();
1241 }
1242 
1243 LogicalResult DmaStartOp::verify() {
1244  unsigned numOperands = getNumOperands();
1245 
1246  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1247  // the number of elements.
1248  if (numOperands < 4)
1249  return emitOpError("expected at least 4 operands");
1250 
1251  // Check types of operands. The order of these calls is important: the later
1252  // calls rely on some type properties to compute the operand position.
1253  // 1. Source memref.
1254  if (!getSrcMemRef().getType().isa<MemRefType>())
1255  return emitOpError("expected source to be of memref type");
1256  if (numOperands < getSrcMemRefRank() + 4)
1257  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1258  << " operands";
1259  if (!getSrcIndices().empty() &&
1260  !llvm::all_of(getSrcIndices().getTypes(),
1261  [](Type t) { return t.isIndex(); }))
1262  return emitOpError("expected source indices to be of index type");
1263 
1264  // 2. Destination memref.
1265  if (!getDstMemRef().getType().isa<MemRefType>())
1266  return emitOpError("expected destination to be of memref type");
1267  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1268  if (numOperands < numExpectedOperands)
1269  return emitOpError() << "expected at least " << numExpectedOperands
1270  << " operands";
1271  if (!getDstIndices().empty() &&
1272  !llvm::all_of(getDstIndices().getTypes(),
1273  [](Type t) { return t.isIndex(); }))
1274  return emitOpError("expected destination indices to be of index type");
1275 
1276  // 3. Number of elements.
1277  if (!getNumElements().getType().isIndex())
1278  return emitOpError("expected num elements to be of index type");
1279 
1280  // 4. Tag memref.
1281  if (!getTagMemRef().getType().isa<MemRefType>())
1282  return emitOpError("expected tag to be of memref type");
1283  numExpectedOperands += getTagMemRefRank();
1284  if (numOperands < numExpectedOperands)
1285  return emitOpError() << "expected at least " << numExpectedOperands
1286  << " operands";
1287  if (!getTagIndices().empty() &&
1288  !llvm::all_of(getTagIndices().getTypes(),
1289  [](Type t) { return t.isIndex(); }))
1290  return emitOpError("expected tag indices to be of index type");
1291 
1292  // Optional stride-related operands must be either both present or both
1293  // absent.
1294  if (numOperands != numExpectedOperands &&
1295  numOperands != numExpectedOperands + 2)
1296  return emitOpError("incorrect number of operands");
1297 
1298  // 5. Strides.
1299  if (isStrided()) {
1300  if (!getStride().getType().isIndex() ||
1301  !getNumElementsPerStride().getType().isIndex())
1302  return emitOpError(
1303  "expected stride and num elements per stride to be of type index");
1304  }
1305 
1306  return success();
1307 }
1308 
1309 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1310  SmallVectorImpl<OpFoldResult> &results) {
1311  /// dma_start(memrefcast) -> dma_start
1312  return foldMemRefCast(*this);
1313 }
1314 
1315 // ---------------------------------------------------------------------------
1316 // DmaWaitOp
1317 // ---------------------------------------------------------------------------
1318 
1319 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1320  SmallVectorImpl<OpFoldResult> &results) {
1321  /// dma_wait(memrefcast) -> dma_wait
1322  return foldMemRefCast(*this);
1323 }
1324 
1325 LogicalResult DmaWaitOp::verify() {
1326  // Check that the number of tag indices matches the tagMemRef rank.
1327  unsigned numTagIndices = getTagIndices().size();
1328  unsigned tagMemRefRank = getTagMemRefRank();
1329  if (numTagIndices != tagMemRefRank)
1330  return emitOpError() << "expected tagIndices to have the same number of "
1331  "elements as the tagMemRef rank, expected "
1332  << tagMemRefRank << ", but got " << numTagIndices;
1333  return success();
1334 }
1335 
1336 //===----------------------------------------------------------------------===//
1337 // ExtractAlignedPointerAsIndexOp
1338 //===----------------------------------------------------------------------===//
1339 
1340 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1341  function_ref<void(Value, StringRef)> setNameFn) {
1342  setNameFn(getResult(), "intptr");
1343 }
1344 
1345 //===----------------------------------------------------------------------===//
1346 // ExtractStridedMetadataOp
1347 //===----------------------------------------------------------------------===//
1348 
1349 /// The number and type of the results are inferred from the
1350 /// shape of the source.
1351 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1352  MLIRContext *context, Optional<Location> location, ValueRange operands,
1353  DictionaryAttr attributes, RegionRange regions,
1354  SmallVectorImpl<Type> &inferredReturnTypes) {
1355  ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes, regions);
1356  auto sourceType = extractAdaptor.getSource().getType().dyn_cast<MemRefType>();
1357  if (!sourceType)
1358  return failure();
1359 
1360  unsigned sourceRank = sourceType.getRank();
1361  IndexType indexType = IndexType::get(context);
1362  auto memrefType =
1363  MemRefType::get({}, sourceType.getElementType(),
1364  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1365  // Base.
1366  inferredReturnTypes.push_back(memrefType);
1367  // Offset.
1368  inferredReturnTypes.push_back(indexType);
1369  // Sizes and strides.
1370  for (unsigned i = 0; i < sourceRank * 2; ++i)
1371  inferredReturnTypes.push_back(indexType);
1372  return success();
1373 }
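// For illustration, the result types inferred above for a hypothetical source
// of type memref<?x4xf32, strided<[?, 1], offset: ?>>: a rank-0 base buffer,
// one index offset, and one index per size and stride.
// ```mlir
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %m
//       : memref<?x4xf32, strided<[?, 1], offset: ?>>
//       -> memref<f32>, index, index, index, index, index
// ```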
1374 
1375 void ExtractStridedMetadataOp::getAsmResultNames(
1376  function_ref<void(Value, StringRef)> setNameFn) {
1377  setNameFn(getBaseBuffer(), "base_buffer");
1378  setNameFn(getOffset(), "offset");
1379  // For multi-result to work properly with pretty names and packed syntax `x:3`
1380  // we can only give a pretty name to the first value in the pack.
1381  if (!getSizes().empty()) {
1382  setNameFn(getSizes().front(), "sizes");
1383  setNameFn(getStrides().front(), "strides");
1384  }
1385 }
1386 
1387 /// Helper function to perform the replacement of all constant uses of `values`
1388 /// by a materialized constant extracted from `maybeConstants`.
1389 /// `values` and `maybeConstants` are expected to have the same size.
1390 template <typename Container>
1391 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1392  Container values,
1393  ArrayRef<OpFoldResult> maybeConstants) {
1394  assert(values.size() == maybeConstants.size() &&
1395  " expected values and maybeConstants of the same size");
1396  bool atLeastOneReplacement = false;
1397  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1398  // Don't materialize a constant if there are no uses: this would induce
1399  // infinite loops in the driver.
1400  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1401  continue;
1402  assert(maybeConstant.template is<Attribute>() &&
1403  "The constified value should be either unchanged (i.e., == result) "
1404  "or a constant");
1405  Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1406  loc, maybeConstant.template get<Attribute>()
1407  .template cast<IntegerAttr>()
1408  .getInt());
1409  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1410  // updateRootInplace: lambda cannot capture structured bindings in C++17
1411  // yet.
1412  op->replaceUsesOfWith(result, constantVal);
1413  atLeastOneReplacement = true;
1414  }
1415  }
1416  return atLeastOneReplacement;
1417 }
1418 
1419 LogicalResult
1420 ExtractStridedMetadataOp::fold(ArrayRef<Attribute> cstOperands,
1421  SmallVectorImpl<OpFoldResult> &results) {
1422  OpBuilder builder(*this);
1423 
1424  bool atLeastOneReplacement = replaceConstantUsesOf(
1425  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1426  getConstifiedMixedOffset());
1427  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1428  getConstifiedMixedSizes());
1429  atLeastOneReplacement |= replaceConstantUsesOf(
1430  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1431 
1432  return success(atLeastOneReplacement);
1433 }
1434 
1435 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1436  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1437  constifyIndexValues(values, getSource().getType(), getContext(),
1438  getConstantSizes, ShapedType::isDynamic);
1439  return values;
1440 }
1441 
1442 SmallVector<OpFoldResult>
1443 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1444  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1445  constifyIndexValues(values, getSource().getType(), getContext(),
1446  getConstantStrides, ShapedType::isDynamic);
1447  return values;
1448 }
1449 
1450 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1451  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1452  SmallVector<OpFoldResult> values(1, offsetOfr);
1453  constifyIndexValues(values, getSource().getType(), getContext(),
1454  getConstantOffset, ShapedType::isDynamic);
1455  return values[0];
1456 }
1457 
1458 //===----------------------------------------------------------------------===//
1459 // GenericAtomicRMWOp
1460 //===----------------------------------------------------------------------===//
1461 
1462 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1463  Value memref, ValueRange ivs) {
1464  result.addOperands(memref);
1465  result.addOperands(ivs);
1466 
1467  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
1468  Type elementType = memrefType.getElementType();
1469  result.addTypes(elementType);
1470 
1471  Region *bodyRegion = result.addRegion();
1472  bodyRegion->push_back(new Block());
1473  bodyRegion->addArgument(elementType, memref.getLoc());
1474  }
1475 }
1476 
1477 LogicalResult GenericAtomicRMWOp::verify() {
1478  auto &body = getRegion();
1479  if (body.getNumArguments() != 1)
1480  return emitOpError("expected single number of entry block arguments");
1481 
1482  if (getResult().getType() != body.getArgument(0).getType())
1483  return emitOpError("expected block argument of the same type result type");
1484 
1485  bool hasSideEffects =
1486  body.walk([&](Operation *nestedOp) {
1487  if (isMemoryEffectFree(nestedOp))
1488  return WalkResult::advance();
1489  nestedOp->emitError(
1490  "body of 'memref.generic_atomic_rmw' should contain "
1491  "only operations with no side effects");
1492  return WalkResult::interrupt();
1493  })
1494  .wasInterrupted();
1495  return hasSideEffects ? failure() : success();
1496 }
1497 
1498 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1499  OperationState &result) {
1500  OpAsmParser::UnresolvedOperand memref;
1501  Type memrefType;
1502  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1503 
1504  Type indexType = parser.getBuilder().getIndexType();
1505  if (parser.parseOperand(memref) ||
1506  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1507  parser.parseColonType(memrefType) ||
1508  parser.resolveOperand(memref, memrefType, result.operands) ||
1509  parser.resolveOperands(ivs, indexType, result.operands))
1510  return failure();
1511 
1512  Region *body = result.addRegion();
1513  if (parser.parseRegion(*body, {}) ||
1514  parser.parseOptionalAttrDict(result.attributes))
1515  return failure();
1516  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
1517  return success();
1518 }
1519 
1520 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1521  p << ' ' << getMemref() << "[" << getIndices()
1522  << "] : " << getMemref().getType() << ' ';
1523  p.printRegion(getRegion());
1524  p.printOptionalAttrDict((*this)->getAttrs());
1525 }
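// [Illustrative note, not part of the upstream source] With the printer and
// parser above, the op round-trips through a textual form like the following
// (memref, index, and body are arbitrary examples):
// ```mlir
// %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
// ^bb0(%current_value : f32):
//   %c1 = arith.constant 1.0 : f32
//   %inc = arith.addf %c1, %current_value : f32
//   memref.atomic_yield %inc : f32
// }
// ```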
1526 
1527 //===----------------------------------------------------------------------===//
1528 // AtomicYieldOp
1529 //===----------------------------------------------------------------------===//
1530 
1531 LogicalResult AtomicYieldOp::verify() {
1532  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1533  Type resultType = getResult().getType();
1534  if (parentType != resultType)
1535  return emitOpError() << "types mismatch between yield op: " << resultType
1536  << " and its parent: " << parentType;
1537  return success();
1538 }
1539 
1540 //===----------------------------------------------------------------------===//
1541 // GlobalOp
1542 //===----------------------------------------------------------------------===//
1543 
1544 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1545  TypeAttr type,
1546  Attribute initialValue) {
1547  p << type;
1548  if (!op.isExternal()) {
1549  p << " = ";
1550  if (op.isUninitialized())
1551  p << "uninitialized";
1552  else
1553  p.printAttributeWithoutType(initialValue);
1554  }
1555 }
1556 
1557 static ParseResult
1558 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1559  Attribute &initialValue) {
1560  Type type;
1561  if (parser.parseType(type))
1562  return failure();
1563 
1564  auto memrefType = type.dyn_cast<MemRefType>();
1565  if (!memrefType || !memrefType.hasStaticShape())
1566  return parser.emitError(parser.getNameLoc())
1567  << "type should be static shaped memref, but got " << type;
1568  typeAttr = TypeAttr::get(type);
1569 
1570  if (parser.parseOptionalEqual())
1571  return success();
1572 
1573  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1574  initialValue = UnitAttr::get(parser.getContext());
1575  return success();
1576  }
1577 
1578  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1579  if (parser.parseAttribute(initialValue, tensorType))
1580  return failure();
1581  if (!initialValue.isa<ElementsAttr>())
1582  return parser.emitError(parser.getNameLoc())
1583  << "initial value should be a unit or elements attribute";
1584  return success();
1585 }
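// [Illustrative note, not part of the upstream source] The parser above
// accepts the three initializer shapes used by memref.global; a sketch (symbol
// names and types are arbitrary examples):
// ```mlir
// memref.global "private" @ext : memref<4xi32>                   // external, no "="
// memref.global @uninit : memref<3xf16> = uninitialized          // unit attribute
// memref.global constant @c : memref<2xf32> = dense<[0.0, 2.0]>  // elements attribute
// ```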
1586 
1587 LogicalResult GlobalOp::verify() {
1588  auto memrefType = getType().dyn_cast<MemRefType>();
1589  if (!memrefType || !memrefType.hasStaticShape())
1590  return emitOpError("type should be static shaped memref, but got ")
1591  << getType();
1592 
1593  // Verify that the initial value, if present, is either a unit attribute or
1594  // an elements attribute.
1595  if (getInitialValue().has_value()) {
1596  Attribute initValue = getInitialValue().value();
1597  if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1598  return emitOpError("initial value should be a unit or elements "
1599  "attribute, but got ")
1600  << initValue;
1601 
1602  // Check that the type of the initial value is compatible with the type of
1603  // the global variable.
1604  if (auto elementsAttr = initValue.dyn_cast<ElementsAttr>()) {
1605  Type initType = elementsAttr.getType();
1606  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1607  if (initType != tensorType)
1608  return emitOpError("initial value expected to be of type ")
1609  << tensorType << ", but was of type " << initType;
1610  }
1611  }
1612 
1613  if (Optional<uint64_t> alignAttr = getAlignment()) {
1614  uint64_t alignment = *alignAttr;
1615 
1616  if (!llvm::isPowerOf2_64(alignment))
1617  return emitError() << "alignment attribute value " << alignment
1618  << " is not a power of 2";
1619  }
1620 
1621  // TODO: verify visibility for declarations.
1622  return success();
1623 }
1624 
1625 ElementsAttr GlobalOp::getConstantInitValue() {
1626  auto initVal = getInitialValue();
1627  if (getConstant() && initVal.has_value())
1628  return initVal.value().cast<ElementsAttr>();
1629  return {};
1630 }
1631 
1632 //===----------------------------------------------------------------------===//
1633 // GetGlobalOp
1634 //===----------------------------------------------------------------------===//
1635 
1636 LogicalResult
1637 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1638  // Verify that the result type is same as the type of the referenced
1639  // memref.global op.
1640  auto global =
1641  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1642  if (!global)
1643  return emitOpError("'")
1644  << getName() << "' does not reference a valid global memref";
1645 
1646  Type resultType = getResult().getType();
1647  if (global.getType() != resultType)
1648  return emitOpError("result type ")
1649  << resultType << " does not match type " << global.getType()
1650  << " of the global memref @" << getName();
1651  return success();
1652 }
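// [Illustrative note, not part of the upstream source] The symbol check above
// rejects a mismatched result type; a well-formed pair looks like (names and
// types are arbitrary examples):
// ```mlir
// memref.global constant @c : memref<2xf32> = dense<[0.0, 2.0]>
// ...
// %0 = memref.get_global @c : memref<2xf32>   // must match the type of @c
// ```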
1653 
1654 //===----------------------------------------------------------------------===//
1655 // LoadOp
1656 //===----------------------------------------------------------------------===//
1657 
1658 LogicalResult LoadOp::verify() {
1659  if (getNumOperands() != 1 + getMemRefType().getRank())
1660  return emitOpError("incorrect number of indices for load");
1661  return success();
1662 }
1663 
1664 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1665  /// load(memrefcast) -> load
1666  if (succeeded(foldMemRefCast(*this)))
1667  return getResult();
1668  return OpFoldResult();
1669 }
1670 
1671 //===----------------------------------------------------------------------===//
1672 // PrefetchOp
1673 //===----------------------------------------------------------------------===//
1674 
1676  p << " " << getMemref() << '[';
1678  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1679  p << ", locality<" << getLocalityHint();
1680  p << ">, " << (getIsDataCache() ? "data" : "instr");
1681  p.printOptionalAttrDict(
1682  (*this)->getAttrs(),
1683  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1684  p << " : " << getMemRefType();
1685 }
1686 
1687 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1688  OpAsmParser::UnresolvedOperand memrefInfo;
1689  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1690  IntegerAttr localityHint;
1691  MemRefType type;
1692  StringRef readOrWrite, cacheType;
1693 
1694  auto indexTy = parser.getBuilder().getIndexType();
1695  auto i32Type = parser.getBuilder().getIntegerType(32);
1696  if (parser.parseOperand(memrefInfo) ||
1697  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1698  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1699  parser.parseComma() || parser.parseKeyword("locality") ||
1700  parser.parseLess() ||
1701  parser.parseAttribute(localityHint, i32Type, "localityHint",
1702  result.attributes) ||
1703  parser.parseGreater() || parser.parseComma() ||
1704  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1705  parser.resolveOperand(memrefInfo, type, result.operands) ||
1706  parser.resolveOperands(indexInfo, indexTy, result.operands))
1707  return failure();
1708 
1709  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1710  return parser.emitError(parser.getNameLoc(),
1711  "rw specifier has to be 'read' or 'write'");
1712  result.addAttribute(
1713  PrefetchOp::getIsWriteAttrStrName(),
1714  parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1715 
1716  if (!cacheType.equals("data") && !cacheType.equals("instr"))
1717  return parser.emitError(parser.getNameLoc(),
1718  "cache type has to be 'data' or 'instr'");
1719 
1720  result.addAttribute(
1721  PrefetchOp::getIsDataCacheAttrStrName(),
1722  parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1723 
1724  return success();
1725 }
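// [Illustrative note, not part of the upstream source] The custom form handled
// by the parser and printer above looks like (operands and shape are arbitrary
// examples):
// ```mlir
// memref.prefetch %A[%i, %j], read, locality<3>, data : memref<400x400xf32>
// ```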
1726 
1727 LogicalResult PrefetchOp::verify() {
1728  if (getNumOperands() != 1 + getMemRefType().getRank())
1729  return emitOpError("too few indices");
1730 
1731  return success();
1732 }
1733 
1734 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1735  SmallVectorImpl<OpFoldResult> &results) {
1736  // prefetch(memrefcast) -> prefetch
1737  return foldMemRefCast(*this);
1738 }
1739 
1740 //===----------------------------------------------------------------------===//
1741 // RankOp
1742 //===----------------------------------------------------------------------===//
1743 
1744 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1745  // Constant fold rank when the rank of the operand is known.
1746  auto type = getOperand().getType();
1747  auto shapedType = type.dyn_cast<ShapedType>();
1748  if (shapedType && shapedType.hasRank())
1749  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1750  return IntegerAttr();
1751 }
1752 
1753 //===----------------------------------------------------------------------===//
1754 // ReinterpretCastOp
1755 //===----------------------------------------------------------------------===//
1756 
1757 void ReinterpretCastOp::getAsmResultNames(
1758  function_ref<void(Value, StringRef)> setNameFn) {
1759  setNameFn(getResult(), "reinterpret_cast");
1760 }
1761 
1762 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1763 /// `staticSizes` and `staticStrides` are automatically filled with
1764 /// source-memref-rank sentinel values that encode dynamic entries.
1765 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1766  MemRefType resultType, Value source,
1767  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1768  ArrayRef<OpFoldResult> strides,
1769  ArrayRef<NamedAttribute> attrs) {
1770  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1771  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1772  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1773  ShapedType::kDynamic);
1774  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1775  ShapedType::kDynamic);
1776  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1777  ShapedType::kDynamic);
1778  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1779  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1780  b.getDenseI64ArrayAttr(staticSizes),
1781  b.getDenseI64ArrayAttr(staticStrides));
1782  result.addAttributes(attrs);
1783 }
1784 
1785 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1786  MemRefType resultType, Value source,
1787  int64_t offset, ArrayRef<int64_t> sizes,
1788  ArrayRef<int64_t> strides,
1789  ArrayRef<NamedAttribute> attrs) {
1790  SmallVector<OpFoldResult> sizeValues =
1791  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1792  return b.getI64IntegerAttr(v);
1793  }));
1794  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1795  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1796  return b.getI64IntegerAttr(v);
1797  }));
1798  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1799  strideValues, attrs);
1800 }
1801 
1802 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1803  MemRefType resultType, Value source, Value offset,
1804  ValueRange sizes, ValueRange strides,
1805  ArrayRef<NamedAttribute> attrs) {
1806  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1807  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1808  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1809  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1810  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1811 }
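// [Illustrative note, not part of the upstream source] A sketch of the IR the
// builders above produce, here with fully static offset/sizes/strides (value
// names and shapes are arbitrary examples):
// ```mlir
// %dst = memref.reinterpret_cast %src to
//     offset: [0], sizes: [10, 10], strides: [10, 1]
//     : memref<?x?xf32> to memref<10x10xf32, strided<[10, 1]>>
// ```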
1812 
1813 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1814 // completed automatically, like we have for subview and extract_slice.
1815 LogicalResult ReinterpretCastOp::verify() {
1816  // The source and result memrefs should be in the same memory space.
1817  auto srcType = getSource().getType().cast<BaseMemRefType>();
1818  auto resultType = getType().cast<MemRefType>();
1819  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1820  return emitError("different memory spaces specified for source type ")
1821  << srcType << " and result memref type " << resultType;
1822  if (srcType.getElementType() != resultType.getElementType())
1823  return emitError("different element types specified for source type ")
1824  << srcType << " and result memref type " << resultType;
1825 
1826  // Match sizes in result memref type and in static_sizes attribute.
1827  for (auto &en :
1828  llvm::enumerate(llvm::zip(resultType.getShape(), getStaticSizes()))) {
1829  int64_t resultSize = std::get<0>(en.value());
1830  int64_t expectedSize = std::get<1>(en.value());
1831  if (!ShapedType::isDynamic(resultSize) &&
1832  !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1833  return emitError("expected result type with size = ")
1834  << expectedSize << " instead of " << resultSize
1835  << " in dim = " << en.index();
1836  }
1837 
1838  // Match offset and strides in static_offset and static_strides attributes. If
1839  // result memref type has no affine map specified, this will assume an
1840  // identity layout.
1841  int64_t resultOffset;
1842  SmallVector<int64_t, 4> resultStrides;
1843  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1844  return emitError("expected result type to have strided layout but found ")
1845  << resultType;
1846 
1847  // Match offset in result memref type and in static_offsets attribute.
1848  int64_t expectedOffset = getStaticOffsets().front();
1849  if (!ShapedType::isDynamic(resultOffset) &&
1850  !ShapedType::isDynamic(expectedOffset) &&
1851  resultOffset != expectedOffset)
1852  return emitError("expected result type with offset = ")
1853  << resultOffset << " instead of " << expectedOffset;
1854 
1855  // Match strides in result memref type and in static_strides attribute.
1856  for (auto &en :
1857  llvm::enumerate(llvm::zip(resultStrides, getStaticStrides()))) {
1858  int64_t resultStride = std::get<0>(en.value());
1859  int64_t expectedStride = std::get<1>(en.value());
1860  if (!ShapedType::isDynamic(resultStride) &&
1861  !ShapedType::isDynamic(expectedStride) &&
1862  resultStride != expectedStride)
1863  return emitError("expected result type with stride = ")
1864  << expectedStride << " instead of " << resultStride
1865  << " in dim = " << en.index();
1866  }
1867 
1868  return success();
1869 }
1870 
1871 OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
1872  Value src = getSource();
1873  auto getPrevSrc = [&]() -> Value {
1874  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1875  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1876  return prev.getSource();
1877 
1878  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1879  if (auto prev = src.getDefiningOp<CastOp>())
1880  return prev.getSource();
1881 
1882  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1883  // are 0.
1884  if (auto prev = src.getDefiningOp<SubViewOp>())
1885  if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1886  return isConstantIntValue(val, 0);
1887  }))
1888  return prev.getSource();
1889 
1890  return nullptr;
1891  };
1892 
1893  if (auto prevSrc = getPrevSrc()) {
1894  getSourceMutable().assign(prevSrc);
1895  return getResult();
1896  }
1897 
1898  return nullptr;
1899 }
1900 
1901 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1902  SmallVector<OpFoldResult> values = getMixedSizes();
1903  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
1904  ShapedType::isDynamic);
1905  return values;
1906 }
1907 
1908 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1909  SmallVector<OpFoldResult> values = getMixedStrides();
1910  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
1911  ShapedType::isDynamic);
1912  return values;
1913 }
1914 
1915 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1916  SmallVector<OpFoldResult> values = getMixedOffsets();
1917  assert(values.size() == 1 &&
1918  "reinterpret_cast must have one and only one offset");
1919  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
1920  ShapedType::isDynamic);
1921  return values[0];
1922 }
1923 
1924 namespace {
1925 /// Replace the sequence:
1926 /// ```
1927 /// base, offset, sizes, strides = extract_strided_metadata src
1928 /// dst = reinterpret_cast base to offset, sizes, strides
1929 /// ```
1930 /// With
1931 ///
1932 /// ```
1933 /// dst = memref.cast src
1934 /// ```
1935 ///
1936 /// Note: The cast operation is only inserted when the type of dst and src
1937 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
1938 ///
1939 /// This pattern also matches when the offset, sizes, and strides don't come
1940 /// directly from the `extract_strided_metadata`'s results but it can be
1941 /// statically proven that they would hold the same values.
1942 ///
1943 /// For instance, the following sequence would be replaced:
1944 /// ```
1945 /// base, offset, sizes, strides =
1946 /// extract_strided_metadata memref : memref<3x4xty>
1947 /// dst = reinterpret_cast base to 0, [3, 4], strides
1948 /// ```
1949 /// Because we know (thanks to the type of the input memref) that the variables
1950 /// `offset` and `sizes` will hold 0 and [3, 4], respectively.
1951 ///
1952 /// Similarly, the following sequence would be replaced:
1953 /// ```
1954 /// c0 = arith.constant 0
1955 /// c4 = arith.constant 4
1956 /// base, offset, sizes, strides =
1957 /// extract_strided_metadata memref : memref<3x4xty>
1958 /// dst = reinterpret_cast base to c0, [3, c4], strides
1959 /// ```
1960 /// Because we know that `offset` and `c0` will hold 0
1961 /// and `c4` will hold 4.
1962 struct ReinterpretCastOpExtractStridedMetadataFolder
1963  : public OpRewritePattern<ReinterpretCastOp> {
1964 public:
1965  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
1966 
1967  LogicalResult matchAndRewrite(ReinterpretCastOp op,
1968  PatternRewriter &rewriter) const override {
1969  auto extractStridedMetadata =
1970  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
1971  if (!extractStridedMetadata)
1972  return failure();
1973  // Check if the reinterpret cast reconstructs a memref with the exact same
1974  // properties as the extract strided metadata.
1975 
1976  // First, check that the strides are the same.
1977  SmallVector<OpFoldResult> extractStridesOfr =
1978  extractStridedMetadata.getConstifiedMixedStrides();
1979  SmallVector<OpFoldResult> reinterpretStridesOfr =
1980  op.getConstifiedMixedStrides();
1981  if (extractStridesOfr.size() != reinterpretStridesOfr.size())
1982  return failure();
1983 
1984  unsigned rank = op.getType().getRank();
1985  for (unsigned i = 0; i < rank; ++i) {
1986  if (extractStridesOfr[i] != reinterpretStridesOfr[i])
1987  return failure();
1988  }
1989 
1990  // Second, check the sizes.
1991  assert(extractStridedMetadata.getSizes().size() ==
1992  op.getMixedSizes().size() &&
1993  "Strides and sizes rank must match");
1994  SmallVector<OpFoldResult> extractSizesOfr =
1995  extractStridedMetadata.getConstifiedMixedSizes();
1996  SmallVector<OpFoldResult> reinterpretSizesOfr =
1997  op.getConstifiedMixedSizes();
1998  for (unsigned i = 0; i < rank; ++i) {
1999  if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2000  return failure();
2001  }
2002  // Finally, check the offset.
2003  assert(op.getMixedOffsets().size() == 1 &&
2004  "reinterpret_cast with more than one offset should have been "
2005  "rejected by the verifier");
2006  OpFoldResult extractOffsetOfr =
2007  extractStridedMetadata.getConstifiedMixedOffset();
2008  OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2009  if (extractOffsetOfr != reinterpretOffsetOfr)
2010  return failure();
2011 
2012  // At this point, we know that the back and forth between extract strided
2013  // metadata and reinterpret cast is a noop. However, the final type of the
2014  // reinterpret cast may not be exactly the same as the original memref.
2015  // E.g., it could be changing a dimension from static to dynamic. Check that
2016  // here and add a cast if necessary.
2017  Type srcTy = extractStridedMetadata.getSource().getType();
2018  if (srcTy == op.getResult().getType())
2019  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2020  else
2021  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2022  extractStridedMetadata.getSource());
2023 
2024  return success();
2025  }
2026 };
2027 } // namespace
2028 
2029 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2030  MLIRContext *context) {
2031  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2032 }
2033 
2034 //===----------------------------------------------------------------------===//
2035 // Reassociative reshape ops
2036 //===----------------------------------------------------------------------===//
2037 
2038 void CollapseShapeOp::getAsmResultNames(
2039  function_ref<void(Value, StringRef)> setNameFn) {
2040  setNameFn(getResult(), "collapse_shape");
2041 }
2042 
2043 void ExpandShapeOp::getAsmResultNames(
2044  function_ref<void(Value, StringRef)> setNameFn) {
2045  setNameFn(getResult(), "expand_shape");
2046 }
2047 
2048 /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
2049 /// result and operand. Layout maps are verified separately.
2050 ///
2051 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2052 /// allowed in a reassocation group.
2053 static LogicalResult
2054 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2055  ArrayRef<int64_t> expandedShape,
2056  ArrayRef<ReassociationIndices> reassociation,
2057  bool allowMultipleDynamicDimsPerGroup) {
2058  // There must be one reassociation group per collapsed dimension.
2059  if (collapsedShape.size() != reassociation.size())
2060  return op->emitOpError("invalid number of reassociation groups: found ")
2061  << reassociation.size() << ", expected " << collapsedShape.size();
2062 
2063  // The next expected expanded dimension index (while iterating over
2064  // reassociation indices).
2065  int64_t nextDim = 0;
2066  for (const auto &it : llvm::enumerate(reassociation)) {
2067  ReassociationIndices group = it.value();
2068  int64_t collapsedDim = it.index();
2069 
2070  bool foundDynamic = false;
2071  for (int64_t expandedDim : group) {
2072  if (expandedDim != nextDim++)
2073  return op->emitOpError("reassociation indices must be contiguous");
2074 
2075  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2076  return op->emitOpError("reassociation index ")
2077  << expandedDim << " is out of bounds";
2078 
2079  // Check if there are multiple dynamic dims in a reassociation group.
2080  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2081  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2082  return op->emitOpError(
2083  "at most one dimension in a reassociation group may be dynamic");
2084  foundDynamic = true;
2085  }
2086  }
2087 
2088  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2089  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2090  return op->emitOpError("collapsed dim (")
2091  << collapsedDim
2092  << ") must be dynamic if and only if reassociation group is "
2093  "dynamic";
2094 
2095  // If all dims in the reassociation group are static, the size of the
2096  // collapsed dim can be verified.
2097  if (!foundDynamic) {
2098  int64_t groupSize = 1;
2099  for (int64_t expandedDim : group)
2100  groupSize *= expandedShape[expandedDim];
2101  if (groupSize != collapsedShape[collapsedDim])
2102  return op->emitOpError("collapsed dim size (")
2103  << collapsedShape[collapsedDim]
2104  << ") must equal reassociation group size (" << groupSize << ")";
2105  }
2106  }
2107 
2108  if (collapsedShape.empty()) {
2109  // Rank 0: All expanded dimensions must be 1.
2110  for (int64_t d : expandedShape)
2111  if (d != 1)
2112  return op->emitOpError(
2113  "rank 0 memrefs can only be extended/collapsed with/from ones");
2114  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2115  // Rank >= 1: Number of dimensions among all reassociation groups must match
2116  // the result memref rank.
2117  return op->emitOpError("expanded rank (")
2118  << expandedShape.size()
2119  << ") inconsistent with number of reassociation indices (" << nextDim
2120  << ")";
2121  }
2122 
2123  return success();
2124 }
2125 
2126 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2127  return getSymbolLessAffineMaps(getReassociationExprs());
2128 }
2129 
2130 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2131  return convertReassociationIndicesToExprs(getContext(),
2132  getReassociationIndices());
2133 }
2134 
2135 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2136  return getSymbolLessAffineMaps(getReassociationExprs());
2137 }
2138 
2139 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2140  return convertReassociationIndicesToExprs(getContext(),
2141  getReassociationIndices());
2142 }
2143 
2144 /// Compute the layout map after expanding a given source MemRef type with the
2145 /// specified reassociation indices.
2146 static FailureOr<StridedLayoutAttr>
2147 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2148  ArrayRef<ReassociationIndices> reassociation) {
2149  int64_t srcOffset;
2150  SmallVector<int64_t> srcStrides;
2151  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2152  return failure();
2153  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2154 
2155  // 1-1 mapping between srcStrides and reassociation packs.
2156  // Each srcStride starts with the given value and gets expanded according to
2157  // the proper entries in resultShape.
2158  // Example:
2159  // srcStrides = [10000, 1 , 100 ],
2160  // reassociations = [ [0], [1], [2, 3, 4]],
2161  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2162  // -> For the purpose of stride calculation, the useful sizes are:
2163  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2164  // resultStrides = [10000, 1, 600, 200, 100]
2165  // Note that a stride does not get expanded along the first entry of each
2166  // shape pack.
2167  SmallVector<int64_t> reverseResultStrides;
2168  reverseResultStrides.reserve(resultShape.size());
2169  unsigned shapeIndex = resultShape.size() - 1;
2170  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2171  ReassociationIndices reassoc = std::get<0>(it);
2172  int64_t currentStrideToExpand = std::get<1>(it);
2173  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2174  using saturated_arith::Wrapper;
2175  reverseResultStrides.push_back(currentStrideToExpand);
2176  currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
2177  Wrapper::size(resultShape[shapeIndex--]))
2178  .asStride();
2179  }
2180  }
2181  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2182  resultStrides.resize(resultShape.size(), 1);
2183  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2184 }
2185 
2186 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2187  MemRefType srcType, ArrayRef<int64_t> resultShape,
2188  ArrayRef<ReassociationIndices> reassociation) {
2189  if (srcType.getLayout().isIdentity()) {
2190  // If the source is contiguous (i.e., no layout map specified), so is the
2191  // result.
2192  MemRefLayoutAttrInterface layout;
2193  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2194  srcType.getMemorySpace());
2195  }
2196 
2197  // Source may not be contiguous. Compute the layout map.
2198  FailureOr<StridedLayoutAttr> computedLayout =
2199  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2200  if (failed(computedLayout))
2201  return failure();
2202  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2203  srcType.getMemorySpace());
2204 }
2205 
2206 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2207  ArrayRef<int64_t> resultShape, Value src,
2208  ArrayRef<ReassociationIndices> reassociation) {
2209  // Only ranked memref source values are supported.
2210  auto srcType = src.getType().cast<MemRefType>();
2211  FailureOr<MemRefType> resultType =
2212  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2213  // Failure of this assertion usually indicates a problem with the source
2214  // type, e.g., could not get strides/offset.
2215  assert(succeeded(resultType) && "could not compute layout");
2216  build(builder, result, *resultType, src, reassociation);
2217 }
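// [Illustrative note, not part of the upstream source] Example of the
// expansion that computeExpandedType/build above support (shapes are arbitrary
// examples):
// ```mlir
// %r = memref.expand_shape %m [[0, 1], [2]]
//     : memref<12x5xf32> into memref<3x4x5xf32>
// ```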
2218 
2219 LogicalResult ExpandShapeOp::verify() {
2220  MemRefType srcType = getSrcType();
2221  MemRefType resultType = getResultType();
2222 
2223  if (srcType.getRank() >= resultType.getRank())
2224  return emitOpError("expected rank expansion, but found source rank ")
2225  << srcType.getRank() << " >= result rank " << resultType.getRank();
2226 
2227  // Verify result shape.
2228  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2229  resultType.getShape(),
2230  getReassociationIndices(),
2231  /*allowMultipleDynamicDimsPerGroup=*/false)))
2232  return failure();
2233 
2234  // Compute expected result type (including layout map).
2235  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2236  srcType, resultType.getShape(), getReassociationIndices());
2237  if (failed(expectedResultType))
2238  return emitOpError("invalid source layout map");
2239 
2240  // Check actual result type.
2241  if (*expectedResultType != resultType)
2242  return emitOpError("expected expanded type to be ")
2243  << *expectedResultType << " but found " << resultType;
2244 
2245  return success();
2246 }
2247 
2248 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2249  MLIRContext *context) {
2252  context);
2253 }
2254 
2255 /// Compute the layout map after collapsing a given source MemRef type with the
2256 /// specified reassociation indices.
2257 ///
2258 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2259 /// not possible to check this by inspecting a MemRefType in the general case.
2260 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2261 /// be valid (and thus accepted by this function) unless `strict = true`.
2262 static FailureOr<StridedLayoutAttr>
2263 computeCollapsedLayoutMap(MemRefType srcType,
2264  ArrayRef<ReassociationIndices> reassociation,
2265  bool strict = false) {
2266  int64_t srcOffset;
2267  SmallVector<int64_t> srcStrides;
2268  auto srcShape = srcType.getShape();
2269  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2270  return failure();
2271 
2272  // The result stride of a reassociation group is the stride of the last entry
2273  // of the reassociation. (TODO: Should be the minimum stride in the
2274  // reassociation because strides are not necessarily sorted. E.g., when using
2275  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2276  // strides are meaningless and could have any arbitrary value.
2277  SmallVector<int64_t> resultStrides;
2278  resultStrides.reserve(reassociation.size());
2279  for (const ReassociationIndices &reassoc : reassociation) {
2280  ArrayRef<int64_t> ref = llvm::makeArrayRef(reassoc);
2281  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2282  ref = ref.drop_back();
2283  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2284  resultStrides.push_back(srcStrides[ref.back()]);
2285  } else {
2286  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2287  // the corresponding stride may have to be skipped. (See above comment.)
2288  // Therefore, the result stride cannot be statically determined and must
2289  // be dynamic.
2290  resultStrides.push_back(ShapedType::kDynamic);
2291  }
2292  }
2293 
2294  // Validate that each reassociation group is contiguous.
2295  unsigned resultStrideIndex = resultStrides.size() - 1;
2296  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2297  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2298  using saturated_arith::Wrapper;
2299  auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
2300  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2301  stride = stride * Wrapper::size(srcShape[idx]);
2302 
2303  // Both source and result stride must have the same static value. In that
2304  // case, we can be sure, that the dimensions are collapsible (because they
2305  // are contiguous).
2306  // If `strict = false` (default during op verification), we accept cases
2307  // where one or both strides are dynamic. This is best effort: We reject
2308  // ops where obviously non-contiguous dims are collapsed, but accept ops
2309  // where we cannot be sure statically. Such ops may fail at runtime. See
2310  // the op documentation for details.
2311  auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
2312  if (strict && (stride.saturated || srcStride.saturated))
2313  return failure();
2314 
2315  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2316  return failure();
2317  }
2318  }
2319  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2320 }
2321 
2322 bool CollapseShapeOp::isGuaranteedCollapsible(
2323  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2324  // MemRefs with identity layout are always collapsible.
2325  if (srcType.getLayout().isIdentity())
2326  return true;
2327 
2328  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2329  /*strict=*/true));
2330 }
2331 
2332 MemRefType CollapseShapeOp::computeCollapsedType(
2333  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2334  SmallVector<int64_t> resultShape;
2335  resultShape.reserve(reassociation.size());
2336  for (const ReassociationIndices &group : reassociation) {
2337  using saturated_arith::Wrapper;
2338  auto groupSize = Wrapper::size(1);
2339  for (int64_t srcDim : group)
2340  groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
2341  resultShape.push_back(groupSize.asSize());
2342  }
2343 
2344  if (srcType.getLayout().isIdentity()) {
2345  // If the source is contiguous (i.e., no layout map specified), so is the
2346  // result.
2347  MemRefLayoutAttrInterface layout;
2348  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2349  srcType.getMemorySpace());
2350  }
2351 
2352  // Source may not be fully contiguous. Compute the layout map.
2353  // Note: Dimensions that are collapsed into a single dim are assumed to be
2354  // contiguous.
2355  FailureOr<StridedLayoutAttr> computedLayout =
2356  computeCollapsedLayoutMap(srcType, reassociation);
2357  assert(succeeded(computedLayout) &&
2358  "invalid source layout map or collapsing non-contiguous dims");
2359  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2360  srcType.getMemorySpace());
2361 }
2362 
2363 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2364  ArrayRef<ReassociationIndices> reassociation,
2365  ArrayRef<NamedAttribute> attrs) {
2366  auto srcType = src.getType().cast<MemRefType>();
2367  MemRefType resultType =
2368  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2369  build(b, result, resultType, src, attrs);
2370  result.addAttribute(getReassociationAttrStrName(),
2371  getReassociationIndicesAttribute(b, reassociation));
2372 }
2373 
2374 LogicalResult CollapseShapeOp::verify() {
2375  MemRefType srcType = getSrcType();
2376  MemRefType resultType = getResultType();
2377 
2378  if (srcType.getRank() <= resultType.getRank())
2379  return emitOpError("expected rank reduction, but found source rank ")
2380  << srcType.getRank() << " <= result rank " << resultType.getRank();
2381 
2382  // Verify result shape.
2383  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2384  srcType.getShape(), getReassociationIndices(),
2385  /*allowMultipleDynamicDimsPerGroup=*/true)))
2386  return failure();
2387 
2388  // Compute expected result type (including layout map).
2389  MemRefType expectedResultType;
2390  if (srcType.getLayout().isIdentity()) {
2391  // If the source is contiguous (i.e., no layout map specified), so is the
2392  // result.
2393  MemRefLayoutAttrInterface layout;
2394  expectedResultType =
2395  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2396  srcType.getMemorySpace());
2397  } else {
2398  // Source may not be fully contiguous. Compute the layout map.
2399  // Note: Dimensions that are collapsed into a single dim are assumed to be
2400  // contiguous.
2401  FailureOr<StridedLayoutAttr> computedLayout =
2402  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2403  if (failed(computedLayout))
2404  return emitOpError(
2405  "invalid source layout map or collapsing non-contiguous dims");
2406  expectedResultType =
2407  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2408  *computedLayout, srcType.getMemorySpace());
2409  }
2410 
2411  if (expectedResultType != resultType)
2412  return emitOpError("expected collapsed type to be ")
2413  << expectedResultType << " but found " << resultType;
2414 
2415  return success();
2416 }
2417 
2418 struct CollapseShapeOpMemRefCastFolder
2419  : public OpRewritePattern<CollapseShapeOp> {
2420 public:
2421  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2422 
2423  LogicalResult matchAndRewrite(CollapseShapeOp op,
2424  PatternRewriter &rewriter) const override {
2425  auto cast = op.getOperand().getDefiningOp<CastOp>();
2426  if (!cast)
2427  return failure();
2428 
2429  if (!CastOp::canFoldIntoConsumerOp(cast))
2430  return failure();
2431 
2432  Type newResultType = CollapseShapeOp::computeCollapsedType(
2433  cast.getOperand().getType().cast<MemRefType>(),
2434  op.getReassociationIndices());
2435 
2436  if (newResultType == op.getResultType()) {
2437  rewriter.updateRootInPlace(
2438  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2439  } else {
2440  Value newOp = rewriter.create<CollapseShapeOp>(
2441  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2442  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2443  }
2444  return success();
2445  }
2446 };
2447 
2448 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2449  MLIRContext *context) {
2453 }
2454 
2455 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
2456  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
2457 }
2458 
2459 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
2460  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
2461 }
2462 
2463 //===----------------------------------------------------------------------===//
2464 // ReshapeOp
2465 //===----------------------------------------------------------------------===//
2466 
2467 void ReshapeOp::getAsmResultNames(
2468  function_ref<void(Value, StringRef)> setNameFn) {
2469  setNameFn(getResult(), "reshape");
2470 }
2471 
2472 LogicalResult ReshapeOp::verify() {
2473  Type operandType = getSource().getType();
2474  Type resultType = getResult().getType();
2475 
2476  Type operandElementType = operandType.cast<ShapedType>().getElementType();
2477  Type resultElementType = resultType.cast<ShapedType>().getElementType();
2478  if (operandElementType != resultElementType)
2479  return emitOpError("element types of source and destination memref "
2480  "types should be the same");
2481 
2482  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
2483  if (!operandMemRefType.getLayout().isIdentity())
2484  return emitOpError("source memref type should have identity affine map");
2485 
2486  int64_t shapeSize = getShape().getType().cast<MemRefType>().getDimSize(0);
2487  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
2488  if (resultMemRefType) {
2489  if (!resultMemRefType.getLayout().isIdentity())
2490  return emitOpError("result memref type should have identity affine map");
2491  if (shapeSize == ShapedType::kDynamic)
2492  return emitOpError("cannot use shape operand with dynamic length to "
2493  "reshape to statically-ranked memref type");
2494  if (shapeSize != resultMemRefType.getRank())
2495  return emitOpError(
2496  "length of shape operand differs from the result's memref rank");
2497  }
2498  return success();
2499 }
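// [Illustrative note, not part of the upstream source] A statically shaped
// reshape that satisfies the checks above (identity layouts, matching element
// types, shape operand length equal to the result rank); types are arbitrary
// examples:
// ```mlir
// %dst = memref.reshape %src(%shape)
//     : (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>
// ```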
2500 
2501 //===----------------------------------------------------------------------===//
2502 // StoreOp
2503 //===----------------------------------------------------------------------===//
2504 
2505 LogicalResult StoreOp::verify() {
2506  if (getNumOperands() != 2 + getMemRefType().getRank())
2507  return emitOpError("store index operand count not equal to memref rank");
2508 
2509  return success();
2510 }
2511 
2512 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
2513  SmallVectorImpl<OpFoldResult> &results) {
2514  /// store(memrefcast) -> store
2515  return foldMemRefCast(*this, getValueToStore());
2516 }
2517 
2518 //===----------------------------------------------------------------------===//
2519 // SubViewOp
2520 //===----------------------------------------------------------------------===//
2521 
2522 void SubViewOp::getAsmResultNames(
2523  function_ref<void(Value, StringRef)> setNameFn) {
2524  setNameFn(getResult(), "subview");
2525 }
2526 
2527 /// A subview result type can be fully inferred from the source type and the
2528 /// static representation of offsets, sizes and strides. Special sentinels
2529 /// encode the dynamic case.
2530 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2531  ArrayRef<int64_t> staticOffsets,
2532  ArrayRef<int64_t> staticSizes,
2533  ArrayRef<int64_t> staticStrides) {
2534  unsigned rank = sourceMemRefType.getRank();
2535  (void)rank;
2536  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2537  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2538  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2539 
2540  // Extract source offset and strides.
2541  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);
2542 
2543  // Compute target offset whose value is:
2544  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2545  int64_t targetOffset = sourceOffset;
2546  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2547  auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
2548  using saturated_arith::Wrapper;
2549  targetOffset =
2550  (Wrapper::offset(targetOffset) +
2551  Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
2552  .asOffset();
2553  }
2554 
2555  // Compute target stride whose value is:
2556  // `sourceStrides_i * staticStrides_i`.
2557  SmallVector<int64_t, 4> targetStrides;
2558  targetStrides.reserve(staticOffsets.size());
2559  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2560  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2561  using saturated_arith::Wrapper;
2562  targetStrides.push_back(
2563  (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
2564  .asStride());
2565  }
2566 
2567  // The type is now known.
2568  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2569  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2570  targetOffset, targetStrides),
2571  sourceMemRefType.getMemorySpace());
2572 }
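// [Illustrative note, not part of the upstream source] A worked example of the
// offset/stride arithmetic above (source type and slice parameters chosen
// arbitrarily):
//   source: memref<8x16xf32>            -> sourceStrides = [16, 1], offset 0
//   staticOffsets = [2, 3], staticSizes = [4, 4], staticStrides = [2, 1]
//   targetOffset  = 0 + 2 * 16 + 3 * 1  = 35
//   targetStrides = [16 * 2, 1 * 1]     = [32, 1]
//   -> memref<4x4xf32, strided<[32, 1], offset: 35>>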
2573 
2574 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2575  ArrayRef<OpFoldResult> offsets,
2576  ArrayRef<OpFoldResult> sizes,
2577  ArrayRef<OpFoldResult> strides) {
2578  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2579  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2580  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2581  ShapedType::kDynamic);
2582  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2583  ShapedType::kDynamic);
2584  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2585  ShapedType::kDynamic);
2586  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2587  staticSizes, staticStrides);
2588 }
2589 
2590 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2591  MemRefType sourceRankedTensorType,
2592  ArrayRef<int64_t> offsets,
2593  ArrayRef<int64_t> sizes,
2594  ArrayRef<int64_t> strides) {
2595  auto inferredType =
2596  inferResultType(sourceRankedTensorType, offsets, sizes, strides)
2597  .cast<MemRefType>();
2598  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2599  "expected ");
2600  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2601  return inferredType;
2602 
2603  // Compute which dimensions are dropped.
2604  Optional<llvm::SmallBitVector> dimsToProject =
2605  computeRankReductionMask(inferredType.getShape(), resultShape);
2606  assert(dimsToProject.has_value() && "invalid rank reduction");
2607 
2608  // Compute the layout and result type.
2609  auto inferredLayout = inferredType.getLayout().cast<StridedLayoutAttr>();
2610  SmallVector<int64_t> rankReducedStrides;
2611  rankReducedStrides.reserve(resultShape.size());
2612  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2613  if (!dimsToProject->contains(idx))
2614  rankReducedStrides.push_back(value);
2615  }
2616  return MemRefType::get(resultShape, inferredType.getElementType(),
2617  StridedLayoutAttr::get(inferredLayout.getContext(),
2618  inferredLayout.getOffset(),
2619  rankReducedStrides),
2620  inferredType.getMemorySpace());
2621 }
2622 
2623 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2624  MemRefType sourceRankedTensorType,
2625  ArrayRef<OpFoldResult> offsets,
2626  ArrayRef<OpFoldResult> sizes,
2627  ArrayRef<OpFoldResult> strides) {
2628  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2629  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2630  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2631  ShapedType::kDynamic);
2632  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2633  ShapedType::kDynamic);
2634  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2635  ShapedType::kDynamic);
2636  return SubViewOp::inferRankReducedResultType(
2637  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2638  staticStrides);
2639 }
2640 
2641 // Build a SubViewOp with mixed static and dynamic entries and custom result
2642 // type. If the type passed is nullptr, it is inferred.
2643 void SubViewOp::build(OpBuilder &b, OperationState &result,
2644  MemRefType resultType, Value source,
2645  ArrayRef<OpFoldResult> offsets,
2646  ArrayRef<OpFoldResult> sizes,
2647  ArrayRef<OpFoldResult> strides,
2648  ArrayRef<NamedAttribute> attrs) {
2649  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2650  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2651  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2652  ShapedType::kDynamic);
2653  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2654  ShapedType::kDynamic);
2655  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2656  ShapedType::kDynamic);
2657  auto sourceMemRefType = source.getType().cast<MemRefType>();
2658  // Structuring implementation this way avoids duplication between builders.
2659  if (!resultType) {
2660  resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2661  staticSizes, staticStrides)
2662  .cast<MemRefType>();
2663  }
2664  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2665  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2666  b.getDenseI64ArrayAttr(staticSizes),
2667  b.getDenseI64ArrayAttr(staticStrides));
2668  result.addAttributes(attrs);
2669 }
2670 
2671 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2672 // type.
2673 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2674  ArrayRef<OpFoldResult> offsets,
2675  ArrayRef<OpFoldResult> sizes,
2676  ArrayRef<OpFoldResult> strides,
2677  ArrayRef<NamedAttribute> attrs) {
2678  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2679 }
2680 
2681 // Build a SubViewOp with static entries and inferred result type.
2682 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2683  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2684  ArrayRef<int64_t> strides,
2685  ArrayRef<NamedAttribute> attrs) {
2686  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2687  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2688  return b.getI64IntegerAttr(v);
2689  }));
2690  SmallVector<OpFoldResult> sizeValues =
2691  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2692  return b.getI64IntegerAttr(v);
2693  }));
2694  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2695  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2696  return b.getI64IntegerAttr(v);
2697  }));
2698  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2699 }
2700 
2701 // Build a SubViewOp with dynamic entries and custom result type. If the
2702 // type passed is nullptr, it is inferred.
2703 void SubViewOp::build(OpBuilder &b, OperationState &result,
2704  MemRefType resultType, Value source,
2705  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2706  ArrayRef<int64_t> strides,
2707  ArrayRef<NamedAttribute> attrs) {
2708  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2709  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2710  return b.getI64IntegerAttr(v);
2711  }));
2712  SmallVector<OpFoldResult> sizeValues =
2713  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2714  return b.getI64IntegerAttr(v);
2715  }));
2716  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2717  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2718  return b.getI64IntegerAttr(v);
2719  }));
2720  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2721  attrs);
2722 }
2723 
2724 // Build a SubViewOp with dynamic entries and custom result type. If the type
2725 // passed is nullptr, it is inferred.
2726 void SubViewOp::build(OpBuilder &b, OperationState &result,
2727  MemRefType resultType, Value source, ValueRange offsets,
2728  ValueRange sizes, ValueRange strides,
2729  ArrayRef<NamedAttribute> attrs) {
2730  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2731  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2732  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2733  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2734  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2735  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2736  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2737 }
2738 
2739 // Build a SubViewOp with dynamic entries and inferred result type.
2740 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2741  ValueRange offsets, ValueRange sizes, ValueRange strides,
2742  ArrayRef<NamedAttribute> attrs) {
2743  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2744 }
2745 
2746 /// For ViewLikeOpInterface.
2747 Value SubViewOp::getViewSource() { return getSource(); }
2748 
2749 /// Return true if t1 and t2 have equal offsets (both dynamic or of same
2750 /// static value).
2751 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2752  int64_t t1Offset, t2Offset;
2753  SmallVector<int64_t> t1Strides, t2Strides;
2754  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2755  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2756  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2757 }
2758 
2759 /// Checks if the `original` type can be rank-reduced to the `reduced` type.
2760 /// This function is a slight variant of the `is subsequence` algorithm where
2761 /// a non-matching dimension must be 1.
2762 static SliceVerificationResult
2763 isRankReducedMemRefType(MemRefType originalType,
2764  MemRefType candidateRankReducedType,
2765  ArrayRef<OpFoldResult> sizes) {
2766  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
2767  if (partialRes != SliceVerificationResult::Success)
2768  return partialRes;
2769 
2770  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
2771  originalType, candidateRankReducedType, sizes);
2772 
2773  // Sizes cannot be matched if an empty mask is returned.
2774  if (!optionalUnusedDimsMask)
2775  return SliceVerificationResult::LayoutMismatch;
2776 
2777  if (originalType.getMemorySpace() !=
2778  candidateRankReducedType.getMemorySpace())
2779  return SliceVerificationResult::MemSpaceMismatch;
2780 
2781  // No amount of stride dropping can reconcile incompatible offsets.
2782  if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
2783  return SliceVerificationResult::LayoutMismatch;
2784 
2785  return SliceVerificationResult::Success;
2786 }
2787 
2788 template <typename OpTy>
2789 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2790  OpTy op, Type expectedType) {
2791  auto memrefType = expectedType.cast<ShapedType>();
2792  switch (result) {
2793  case SliceVerificationResult::Success:
2794  return success();
2795  case SliceVerificationResult::RankTooLarge:
2796  return op.emitError("expected result rank to be smaller or equal to ")
2797  << "the source rank. ";
2798  case SliceVerificationResult::SizeMismatch:
2799  return op.emitError("expected result type to be ")
2800  << expectedType
2801  << " or a rank-reduced version. (mismatch of result sizes) ";
2802  case SliceVerificationResult::ElemTypeMismatch:
2803  return op.emitError("expected result element type to be ")
2804  << memrefType.getElementType();
2805  case SliceVerificationResult::MemSpaceMismatch:
2806  return op.emitError("expected result and source memory spaces to match.");
2807  case SliceVerificationResult::LayoutMismatch:
2808  return op.emitError("expected result type to be ")
2809  << expectedType
2810  << " or a rank-reduced version. (mismatch of result layout) ";
2811  }
2812  llvm_unreachable("unexpected subview verification result");
2813 }
2814 
2815 /// Verifier for SubViewOp.
2816 LogicalResult SubViewOp::verify() {
2817  MemRefType baseType = getSourceType();
2818  MemRefType subViewType = getType();
2819 
2820  // The base memref and the view memref should be in the same memory space.
2821  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2822  return emitError("different memory spaces specified for base memref "
2823  "type ")
2824  << baseType << " and subview memref type " << subViewType;
2825 
2826  // Verify that the base memref type has a strided layout map.
2827  if (!isStrided(baseType))
2828  return emitError("base type ") << baseType << " is not strided";
2829 
2830  // Verify result type against inferred type.
2831  auto expectedType = SubViewOp::inferResultType(
2832  baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
2833 
2834  auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
2835  subViewType, getMixedSizes());
2836  return produceSubViewErrorMsg(result, *this, expectedType);
2837 }
2838 
2839 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2840  return os << "range " << range.offset << ":" << range.size << ":"
2841  << range.stride;
2842 }
2843 
2844 /// Return the list of Range (i.e. offset, size, stride). Each Range
2845 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2846 /// with `b` at location `loc`.
2847 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2848  OpBuilder &b, Location loc) {
2849  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
2850  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
2851  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
2852  SmallVector<Range, 8> res;
2853  unsigned rank = ranks[0];
2854  res.reserve(rank);
2855  for (unsigned idx = 0; idx < rank; ++idx) {
2856  Value offset =
2857  op.isDynamicOffset(idx)
2858  ? op.getDynamicOffset(idx)
2859  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
2860  Value size =
2861  op.isDynamicSize(idx)
2862  ? op.getDynamicSize(idx)
2863  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
2864  Value stride =
2865  op.isDynamicStride(idx)
2866  ? op.getDynamicStride(idx)
2867  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
2868  res.emplace_back(Range{offset, size, stride});
2869  }
2870  return res;
2871 }
2872 
2873 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
2874 /// to deduce the result type for the given `sourceType`. Additionally, reduce
2875 /// the rank of the inferred result type if `currentResultType` is lower rank
2876 /// than `currentSourceType`. Use this signature if `sourceType` is updated
2877 /// together with the result type. In this case, it is important to compute
2878 /// the dropped dimensions using `currentSourceType` whose strides align with
2879 /// `currentResultType`.
2880 static MemRefType getCanonicalSubViewResultType(
2881  MemRefType currentResultType, MemRefType currentSourceType,
2882  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
2883  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
2884  auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
2885  mixedSizes, mixedStrides)
2886  .cast<MemRefType>();
2887  llvm::Optional<llvm::SmallBitVector> unusedDims =
2888  computeMemRefRankReductionMask(currentSourceType, currentResultType,
2889  mixedSizes);
2890  // Return nullptr as failure mode.
2891  if (!unusedDims)
2892  return nullptr;
2893 
2894  auto layout = nonRankReducedType.getLayout().cast<StridedLayoutAttr>();
2895  SmallVector<int64_t> shape, strides;
2896  unsigned numDimsAfterReduction =
2897  nonRankReducedType.getRank() - unusedDims->count();
2898  shape.reserve(numDimsAfterReduction);
2899  strides.reserve(numDimsAfterReduction);
2900  for (const auto &[idx, size, stride] :
2901  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
2902  nonRankReducedType.getShape(), layout.getStrides())) {
2903  if (unusedDims->test(idx))
2904  continue;
2905  shape.push_back(size);
2906  strides.push_back(stride);
2907  }
2908 
2909  return MemRefType::get(shape, nonRankReducedType.getElementType(),
2910  StridedLayoutAttr::get(sourceType.getContext(),
2911  layout.getOffset(), strides),
2912  nonRankReducedType.getMemorySpace());
2913 }
2914 
2915 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
2916 /// to deduce the result type. Additionally, reduce the rank of the inferred
2917 /// result type if `currentResultType` is lower rank than `sourceType`.
2918 static MemRefType getCanonicalSubViewResultType(
2919  MemRefType currentResultType, MemRefType sourceType,
2920  ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
2921  ArrayRef<OpFoldResult> mixedStrides) {
2922  return getCanonicalSubViewResultType(currentResultType, sourceType,
2923  sourceType, mixedOffsets, mixedSizes,
2924  mixedStrides);
2925 }
2926 
2927 /// Helper method to check if a `subview` operation is trivially a no-op. This
2928 /// is the case if all offsets are zero, all strides are 1, and the source
2929 /// shape is the same as the size of the subview. In such cases, the subview can
2930 /// be folded into its source.
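/// For illustration, a hypothetical subview that meets all three conditions:
/// ```
/// %1 = memref.subview %0[0, 0] [4, 8] [1, 1]
///        : memref<4x8xf32> to memref<4x8xf32>
/// ```
/// is a no-op, and all uses of %1 can be replaced by %0.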
2931 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
2932  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
2933  return false;
2934 
2935  auto mixedOffsets = subViewOp.getMixedOffsets();
2936  auto mixedSizes = subViewOp.getMixedSizes();
2937  auto mixedStrides = subViewOp.getMixedStrides();
2938 
2939  // Check offsets are zero.
2940  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
2941  Optional<int64_t> intValue = getConstantIntValue(ofr);
2942  return !intValue || intValue.value() != 0;
2943  }))
2944  return false;
2945 
2946  // Check strides are one.
2947  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
2948  Optional<int64_t> intValue = getConstantIntValue(ofr);
2949  return !intValue || intValue.value() != 1;
2950  }))
2951  return false;
2952 
2953  // Check that all size values are static and match the (static) source shape.
2954  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
2955  for (const auto &size : llvm::enumerate(mixedSizes)) {
2956  Optional<int64_t> intValue = getConstantIntValue(size.value());
2957  if (!intValue || *intValue != sourceShape[size.index()])
2958  return false;
2959  }
2960  // All conditions met. The `SubViewOp` is foldable as a no-op.
2961  return true;
2962 }
2963 
2964 namespace {
2965 /// Pattern to rewrite a subview op with MemRefCast arguments.
2966 /// This essentially pushes memref.cast past its consuming subview when
2967 /// `canFoldIntoConsumerOp` is true.
2968 ///
2969 /// Example:
2970 /// ```
2971 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
2972 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
2973 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
2974 /// ```
2975 /// is rewritten into:
2976 /// ```
2977 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
2978 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
2979 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
2980 /// ```
2981 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2982 public:
2983  using OpRewritePattern<SubViewOp>::OpRewritePattern;
2984 
2985  LogicalResult matchAndRewrite(SubViewOp subViewOp,
2986  PatternRewriter &rewriter) const override {
2987  // If any operand is constant, bail out and let SubViewOpConstantFolder
2988  // kick in.
2989  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2990  return matchPattern(operand, matchConstantIndex());
2991  }))
2992  return failure();
2993 
2994  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
2995  if (!castOp)
2996  return failure();
2997 
2998  if (!CastOp::canFoldIntoConsumerOp(castOp))
2999  return failure();
3000 
3001  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3002  // the MemRefCastOp source operand type to infer the result type and the
3003  // current SubViewOp source operand type to compute the dropped dimensions
3004  // if the operation is rank-reducing.
3005  auto resultType = getCanonicalSubViewResultType(
3006  subViewOp.getType(), subViewOp.getSourceType(),
3007  castOp.getSource().getType().cast<MemRefType>(),
3008  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3009  subViewOp.getMixedStrides());
3010  if (!resultType)
3011  return failure();
3012 
3013  Value newSubView = rewriter.create<SubViewOp>(
3014  subViewOp.getLoc(), resultType, castOp.getSource(),
3015  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3016  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3017  subViewOp.getStaticStrides());
3018  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3019  newSubView);
3020  return success();
3021  }
3022 };
3023 
3024 /// Canonicalize subview ops that are no-ops. When the source type differs
3025 /// from the result type (e.g. due to an `affine_map` layout), replace with a cast.
3026 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3027 public:
3028  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3029 
3030  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3031  PatternRewriter &rewriter) const override {
3032  if (!isTrivialSubViewOp(subViewOp))
3033  return failure();
3034  if (subViewOp.getSourceType() == subViewOp.getType()) {
3035  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3036  return success();
3037  }
3038  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3039  subViewOp.getSource());
3040  return success();
3041  }
3042 };
3043 } // namespace
3044 
3045 /// Return the canonical type of the result of a subview.
3046 struct SubViewReturnTypeCanonicalizer {
3047  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3048  ArrayRef<OpFoldResult> mixedSizes,
3049  ArrayRef<OpFoldResult> mixedStrides) {
3050  return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
3051  mixedOffsets, mixedSizes,
3052  mixedStrides);
3053  }
3054 };
3055 
3056 /// A canonicalizer wrapper to replace SubViewOps.
3057 struct SubViewCanonicalizer {
3058  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3059  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3060  }
3061 };
3062 
3063 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3064  MLIRContext *context) {
3065  results
3066  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3067  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3068  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3069 }
3070 
3071 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
3072  auto resultShapedType = getResult().getType().cast<ShapedType>();
3073  auto sourceShapedType = getSource().getType().cast<ShapedType>();
3074 
3075  if (resultShapedType.hasStaticShape() &&
3076  resultShapedType == sourceShapedType) {
3077  return getViewSource();
3078  }
3079 
3080  return {};
3081 }
3082 
3083 //===----------------------------------------------------------------------===//
3084 // TransposeOp
3085 //===----------------------------------------------------------------------===//
3086 
3087 void TransposeOp::getAsmResultNames(
3088  function_ref<void(Value, StringRef)> setNameFn) {
3089  setNameFn(getResult(), "transpose");
3090 }
3091 
3092 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
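/// For illustration (hypothetical types): applying (d0, d1) -> (d1, d0) to
/// memref<4x8xf32> (strides [8, 1], offset 0) yields
/// memref<8x4xf32, strided<[1, 8]>>.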
3093 static MemRefType inferTransposeResultType(MemRefType memRefType,
3094  AffineMap permutationMap) {
3095  auto rank = memRefType.getRank();
3096  auto originalSizes = memRefType.getShape();
3097  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
3098  assert(originalStrides.size() == static_cast<unsigned>(rank));
3099 
3100  // Compute permuted sizes and strides.
3101  SmallVector<int64_t> sizes(rank, 0);
3102  SmallVector<int64_t> strides(rank, 1);
3103  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
3104  unsigned position = en.value().cast<AffineDimExpr>().getPosition();
3105  sizes[en.index()] = originalSizes[position];
3106  strides[en.index()] = originalStrides[position];
3107  }
3108 
3109  return MemRefType::Builder(memRefType)
3110  .setShape(sizes)
3111  .setLayout(
3112  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3113 }
3114 
3115 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3116  AffineMapAttr permutation,
3117  ArrayRef<NamedAttribute> attrs) {
3118  auto permutationMap = permutation.getValue();
3119  assert(permutationMap);
3120 
3121  auto memRefType = in.getType().cast<MemRefType>();
3122  // Compute result type.
3123  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3124 
3125  build(b, result, resultType, in, attrs);
3126  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3127 }
3128 
3129 // transpose $in $permutation attr-dict : type($in) `to` type(results)
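// For illustration, a hypothetical op in this printed form:
//   %1 = memref.transpose %0 (d0, d1) -> (d1, d0)
//          : memref<?x?xf32, strided<[?, 1], offset: ?>> to
//            memref<?x?xf32, strided<[1, ?], offset: ?>>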
3130 void TransposeOp::print(OpAsmPrinter &p) {
3131  p << " " << getIn() << " " << getPermutation();
3132  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3133  p << " : " << getIn().getType() << " to " << getType();
3134 }
3135 
3136 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3137  OpAsmParser::UnresolvedOperand in;
3138  AffineMap permutation;
3139  MemRefType srcType, dstType;
3140  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3141  parser.parseOptionalAttrDict(result.attributes) ||
3142  parser.parseColonType(srcType) ||
3143  parser.resolveOperand(in, srcType, result.operands) ||
3144  parser.parseKeywordType("to", dstType) ||
3145  parser.addTypeToList(dstType, result.types))
3146  return failure();
3147 
3148  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3149  AffineMapAttr::get(permutation));
3150  return success();
3151 }
3152 
3153 LogicalResult TransposeOp::verify() {
3154  if (!getPermutation().isPermutation())
3155  return emitOpError("expected a permutation map");
3156  if (getPermutation().getNumDims() != getShapedType().getRank())
3157  return emitOpError("expected a permutation map of same rank as the input");
3158 
3159  auto srcType = getIn().getType().cast<MemRefType>();
3160  auto dstType = getType().cast<MemRefType>();
3161  auto transposedType = inferTransposeResultType(srcType, getPermutation());
3162  if (dstType != transposedType)
3163  return emitOpError("output type ")
3164  << dstType << " does not match transposed input type " << srcType
3165  << ", " << transposedType;
3166  return success();
3167 }
3168 
3169 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
3170  if (succeeded(foldMemRefCast(*this)))
3171  return getResult();
3172  return {};
3173 }
3174 
3175 //===----------------------------------------------------------------------===//
3176 // ViewOp
3177 //===----------------------------------------------------------------------===//
3178 
3179 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3180  setNameFn(getResult(), "view");
3181 }
3182 
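// For illustration, a hypothetical well-formed view (identity layouts,
// matching memory spaces, one size operand per dynamic result dimension):
//   %v = memref.view %buf[%offset][%n] : memref<2048xi8> to memref<?x64xf32>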
3183 LogicalResult ViewOp::verify() {
3184  auto baseType = getOperand(0).getType().cast<MemRefType>();
3185  auto viewType = getType();
3186 
3187  // The base memref should have identity layout map (or none).
3188  if (!baseType.getLayout().isIdentity())
3189  return emitError("unsupported map for base memref type ") << baseType;
3190 
3191  // The result memref should have identity layout map (or none).
3192  if (!viewType.getLayout().isIdentity())
3193  return emitError("unsupported map for result memref type ") << viewType;
3194 
3195  // The base memref and the view memref should be in the same memory space.
3196  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3197  return emitError("different memory spaces specified for base memref "
3198  "type ")
3199  << baseType << " and view memref type " << viewType;
3200 
3201  // Verify that we have the correct number of sizes for the result type.
3202  unsigned numDynamicDims = viewType.getNumDynamicDims();
3203  if (getSizes().size() != numDynamicDims)
3204  return emitError("incorrect number of size operands for type ") << viewType;
3205 
3206  return success();
3207 }
3208 
3209 Value ViewOp::getViewSource() { return getSource(); }
3210 
3211 namespace {
3212 
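/// Fold constant size operands of a ViewOp into its result type. For
/// illustration (hypothetical IR; %buf and %off are placeholders):
/// ```
/// %c64 = arith.constant 64 : index
/// %v = memref.view %buf[%off][%c64] : memref<2048xi8> to memref<?x4xf32>
/// ```
/// becomes
/// ```
/// %v0 = memref.view %buf[%off][] : memref<2048xi8> to memref<64x4xf32>
/// %v = memref.cast %v0 : memref<64x4xf32> to memref<?x4xf32>
/// ```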
3213 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3214  using OpRewritePattern<ViewOp>::OpRewritePattern;
3215 
3216  LogicalResult matchAndRewrite(ViewOp viewOp,
3217  PatternRewriter &rewriter) const override {
3218  // Return if none of the operands are constants.
3219  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3220  return matchPattern(operand, matchConstantIndex());
3221  }))
3222  return failure();
3223 
3224  // Get result memref type.
3225  auto memrefType = viewOp.getType();
3226 
3227  // Get offset from the old memref view type 'memrefType'.
3228  int64_t oldOffset;
3229  SmallVector<int64_t, 4> oldStrides;
3230  if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
3231  return failure();
3232  assert(oldOffset == 0 && "Expected 0 offset");
3233 
3234  SmallVector<Value, 4> newOperands;
3235 
3236  // Offset cannot be folded into result type.
3237 
3238  // Fold any dynamic dim operands which are produced by a constant.
3239  SmallVector<int64_t, 4> newShapeConstants;
3240  newShapeConstants.reserve(memrefType.getRank());
3241 
3242  unsigned dynamicDimPos = 0;
3243  unsigned rank = memrefType.getRank();
3244  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3245  int64_t dimSize = memrefType.getDimSize(dim);
3246  // If this is already a static dimension, keep it.
3247  if (!ShapedType::isDynamic(dimSize)) {
3248  newShapeConstants.push_back(dimSize);
3249  continue;
3250  }
3251  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3252  if (auto constantIndexOp =
3253  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3254  // Dynamic shape dimension will be folded.
3255  newShapeConstants.push_back(constantIndexOp.value());
3256  } else {
3257  // Dynamic shape dimension not folded; copy operand from old memref.
3258  newShapeConstants.push_back(dimSize);
3259  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3260  }
3261  dynamicDimPos++;
3262  }
3263 
3264  // Create new memref type with constant folded dims.
3265  MemRefType newMemRefType =
3266  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3267  // Nothing new, don't fold.
3268  if (newMemRefType == memrefType)
3269  return failure();
3270 
3271  // Create new ViewOp.
3272  auto newViewOp = rewriter.create<ViewOp>(
3273  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3274  viewOp.getByteShift(), newOperands);
3275  // Insert a cast so we have the same type as the old memref type.
3276  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3277  return success();
3278  }
3279 };
3280 
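/// Replace view(cast(alloc)) with a view of the allocation directly. For
/// illustration (hypothetical IR; %off is a placeholder):
/// ```
/// %a = memref.alloc() : memref<2048xi8>
/// %c = memref.cast %a : memref<2048xi8> to memref<?xi8>
/// %v = memref.view %c[%off][] : memref<?xi8> to memref<64x4xf32>
/// ```
/// becomes
/// ```
/// %v = memref.view %a[%off][] : memref<2048xi8> to memref<64x4xf32>
/// ```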
3281 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3282  using OpRewritePattern<ViewOp>::OpRewritePattern;
3283 
3284  LogicalResult matchAndRewrite(ViewOp viewOp,
3285  PatternRewriter &rewriter) const override {
3286  Value memrefOperand = viewOp.getOperand(0);
3287  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3288  if (!memrefCastOp)
3289  return failure();
3290  Value allocOperand = memrefCastOp.getOperand();
3291  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3292  if (!allocOp)
3293  return failure();
3294  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3295  viewOp.getByteShift(),
3296  viewOp.getSizes());
3297  return success();
3298  }
3299 };
3300 
3301 } // namespace
3302 
3303 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3304  MLIRContext *context) {
3305  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3306 }
3307 
3308 //===----------------------------------------------------------------------===//
3309 // AtomicRMWOp
3310 //===----------------------------------------------------------------------===//
3311 
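// For illustration, a hypothetical well-formed op: an integer kind applied to
// an integer element type, with one subscript per memref dimension:
//   %old = memref.atomic_rmw addi %val, %mem[%i, %j]
//            : (i32, memref<10x10xi32>) -> i32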
3312 LogicalResult AtomicRMWOp::verify() {
3313  if (getMemRefType().getRank() != getNumOperands() - 2)
3314  return emitOpError(
3315  "expects the number of subscripts to be equal to memref rank");
3316  switch (getKind()) {
3317  case arith::AtomicRMWKind::addf:
3318  case arith::AtomicRMWKind::maxf:
3319  case arith::AtomicRMWKind::minf:
3320  case arith::AtomicRMWKind::mulf:
3321  if (!getValue().getType().isa<FloatType>())
3322  return emitOpError() << "with kind '"
3323  << arith::stringifyAtomicRMWKind(getKind())
3324  << "' expects a floating-point type";
3325  break;
3326  case arith::AtomicRMWKind::addi:
3327  case arith::AtomicRMWKind::maxs:
3328  case arith::AtomicRMWKind::maxu:
3329  case arith::AtomicRMWKind::mins:
3330  case arith::AtomicRMWKind::minu:
3331  case arith::AtomicRMWKind::muli:
3332  case arith::AtomicRMWKind::ori:
3333  case arith::AtomicRMWKind::andi:
3334  if (!getValue().getType().isa<IntegerType>())
3335  return emitOpError() << "with kind '"
3336  << arith::stringifyAtomicRMWKind(getKind())
3337  << "' expects an integer type";
3338  break;
3339  default:
3340  break;
3341  }
3342  return success();
3343 }
3344 
3345 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
3346  /// atomicrmw(memrefcast) -> atomicrmw
3347  if (succeeded(foldMemRefCast(*this, getValue())))
3348  return getResult();
3349  return OpFoldResult();
3350 }
3351 
3352 //===----------------------------------------------------------------------===//
3353 // TableGen'd op method definitions
3354 //===----------------------------------------------------------------------===//
3355 
3356 #define GET_OP_CLASSES
3357 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"