MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 namespace {
30 /// Idiomatic saturated operations on offsets, sizes and strides.
31 namespace saturated_arith {
32 struct Wrapper {
33  static Wrapper stride(int64_t v) {
34  return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
35  : Wrapper{false, v};
36  }
37  static Wrapper offset(int64_t v) {
38  return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
39  : Wrapper{false, v};
40  }
41  static Wrapper size(int64_t v) {
42  return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
43  }
44  int64_t asOffset() {
45  return saturated ? ShapedType::kDynamicStrideOrOffset : v;
46  }
47  int64_t asSize() { return saturated ? ShapedType::kDynamicSize : v; }
48  int64_t asStride() {
49  return saturated ? ShapedType::kDynamicStrideOrOffset : v;
50  }
51  bool operator==(Wrapper other) {
52  return (saturated && other.saturated) ||
53  (!saturated && !other.saturated && v == other.v);
54  }
55  bool operator!=(Wrapper other) { return !(*this == other); }
56  Wrapper operator+(Wrapper other) {
57  if (saturated || other.saturated)
58  return Wrapper{true, 0};
59  return Wrapper{false, other.v + v};
60  }
61  Wrapper operator*(Wrapper other) {
62  if (saturated || other.saturated)
63  return Wrapper{true, 0};
64  return Wrapper{false, other.v * v};
65  }
66  bool saturated;
67  int64_t v;
68 };
69 } // namespace saturated_arith
70 } // namespace
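
// For example (illustrative use of the helpers above): combining a static and
// a dynamic stride saturates to dynamic, i.e.
// Wrapper::stride(4) * Wrapper::stride(ShapedType::kDynamicStrideOrOffset)
// is saturated, so asStride() returns ShapedType::kDynamicStrideOrOffset.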
71 
72 /// Materialize a single constant operation from a given attribute value with
73 /// the desired resultant type.
74 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
75  Attribute value, Type type,
76  Location loc) {
77  if (arith::ConstantOp::isBuildableWith(value, type))
78  return builder.create<arith::ConstantOp>(loc, value, type);
79  return nullptr;
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // Common canonicalization pattern support logic
84 //===----------------------------------------------------------------------===//
85 
86 /// This is a common class used for patterns of the form
87 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
88 /// into the root operation directly.
89 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
90  bool folded = false;
91  for (OpOperand &operand : op->getOpOperands()) {
92  auto cast = operand.get().getDefiningOp<CastOp>();
93  if (cast && operand.get() != inner &&
94  !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
95  operand.set(cast.getOperand());
96  folded = true;
97  }
98  }
99  return success(folded);
100 }
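
// Illustrative fold performed via foldMemRefCast (hypothetical IR, not taken
// from the test suite):
//
//   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
//   memref.dealloc %0 : memref<?xf32>
//
// becomes
//
//   memref.dealloc %arg : memref<8xf32>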
101 
102 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
103 /// type.
104 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
105  if (auto memref = type.dyn_cast<MemRefType>())
106  return RankedTensorType::get(memref.getShape(), memref.getElementType());
107  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
108  return UnrankedTensorType::get(memref.getElementType());
109  return NoneType::get(type.getContext());
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // AllocOp / AllocaOp
114 //===----------------------------------------------------------------------===//
115 
116 template <typename AllocLikeOp>
117 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
118  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
119  "applies to only alloc or alloca");
120  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
121  if (!memRefType)
122  return op.emitOpError("result must be a memref");
123 
124  if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
125  memRefType.getNumDynamicDims())
126  return op.emitOpError("dimension operand count does not equal memref "
127  "dynamic dimension count");
128 
129  unsigned numSymbols = 0;
130  if (!memRefType.getLayout().isIdentity())
131  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
132  if (op.getSymbolOperands().size() != numSymbols)
133  return op.emitOpError("symbol operand count does not equal memref symbol "
134  "count: expected ")
135  << numSymbols << ", got " << op.getSymbolOperands().size();
136 
137  return success();
138 }
139 
140 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
141 
142 LogicalResult AllocaOp::verify() {
143  // An alloca op needs to have an ancestor with an allocation scope trait.
144  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
145  return emitOpError(
146  "requires an ancestor op with AutomaticAllocationScope trait");
147 
148  return verifyAllocLikeOp(*this);
149 }
150 
151 namespace {
152 /// Fold constant dimensions into an alloc like operation.
153 template <typename AllocLikeOp>
154 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
155  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
156 
157  LogicalResult matchAndRewrite(AllocLikeOp alloc,
158  PatternRewriter &rewriter) const override {
159  // Check to see if any dimension operands are constants. If so, we can
160  // substitute and drop them.
161  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
162  return matchPattern(operand, matchConstantIndex());
163  }))
164  return failure();
165 
166  auto memrefType = alloc.getType();
167 
168  // Ok, we have one or more constant operands. Collect the non-constant ones
169  // and keep track of the resultant memref type to build.
170  SmallVector<int64_t, 4> newShapeConstants;
171  newShapeConstants.reserve(memrefType.getRank());
172  SmallVector<Value, 4> dynamicSizes;
173 
174  unsigned dynamicDimPos = 0;
175  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
176  int64_t dimSize = memrefType.getDimSize(dim);
177  // If this is already a static dimension, keep it.
178  if (dimSize != -1) {
179  newShapeConstants.push_back(dimSize);
180  continue;
181  }
182  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
183  auto *defOp = dynamicSize.getDefiningOp();
184  if (auto constantIndexOp =
185  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
186  // Dynamic shape dimension will be folded.
187  newShapeConstants.push_back(constantIndexOp.value());
188  } else {
189  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
190  newShapeConstants.push_back(-1);
191  dynamicSizes.push_back(dynamicSize);
192  }
193  dynamicDimPos++;
194  }
195 
196  // Create new memref type (which will have fewer dynamic dimensions).
197  MemRefType newMemRefType =
198  MemRefType::Builder(memrefType).setShape(newShapeConstants);
199  assert(static_cast<int64_t>(dynamicSizes.size()) ==
200  newMemRefType.getNumDynamicDims());
201 
202  // Create and insert the alloc op for the new memref.
203  auto newAlloc = rewriter.create<AllocLikeOp>(
204  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
205  alloc.getAlignmentAttr());
206  // Insert a cast so we have the same type as the old alloc.
207  auto resultCast =
208  rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
209 
210  rewriter.replaceOp(alloc, {resultCast});
211  return success();
212  }
213 };
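
// As an illustration (hypothetical IR), SimplifyAllocConst rewrites
//
//   %c4 = arith.constant 4 : index
//   %0 = memref.alloc(%c4, %n) : memref<?x?xf32>
//
// into
//
//   %0 = memref.alloc(%n) : memref<4x?xf32>
//   %1 = memref.cast %0 : memref<4x?xf32> to memref<?x?xf32>
//
// so existing uses keep their original type via the inserted cast.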
214 
215 /// Fold alloc operations with no users or only store and dealloc uses.
216 template <typename T>
217 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
218  using OpRewritePattern<T>::OpRewritePattern;
219 
220  LogicalResult matchAndRewrite(T alloc,
221  PatternRewriter &rewriter) const override {
222  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
223  if (auto storeOp = dyn_cast<StoreOp>(op))
224  return storeOp.getValue() == alloc;
225  return !isa<DeallocOp>(op);
226  }))
227  return failure();
228 
229  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
230  rewriter.eraseOp(user);
231 
232  rewriter.eraseOp(alloc);
233  return success();
234  }
235 };
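
// For example (illustrative IR), an allocation whose only uses are stores into
// it and a dealloc is removed entirely by SimplifyDeadAlloc:
//
//   %0 = memref.alloc() : memref<4xf32>
//   memref.store %cst, %0[%i] : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>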
236 } // namespace
237 
238 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
239  MLIRContext *context) {
240  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
241 }
242 
243 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
244  MLIRContext *context) {
245  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
246  context);
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // AllocaScopeOp
251 //===----------------------------------------------------------------------===//
252 
253 void AllocaScopeOp::print(OpAsmPrinter &p) {
254  bool printBlockTerminators = false;
255 
256  p << ' ';
257  if (!getResults().empty()) {
258  p << " -> (" << getResultTypes() << ")";
259  printBlockTerminators = true;
260  }
261  p << ' ';
262  p.printRegion(getBodyRegion(),
263  /*printEntryBlockArgs=*/false,
264  /*printBlockTerminators=*/printBlockTerminators);
265  p.printOptionalAttrDict((*this)->getAttrs());
266 }
267 
268 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
269  // Create a region for the body.
270  result.regions.reserve(1);
271  Region *bodyRegion = result.addRegion();
272 
273  // Parse optional results type list.
274  if (parser.parseOptionalArrowTypeList(result.types))
275  return failure();
276 
277  // Parse the body region.
278  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
279  return failure();
280  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
281  result.location);
282 
283  // Parse the optional attribute list.
284  if (parser.parseOptionalAttrDict(result.attributes))
285  return failure();
286 
287  return success();
288 }
289 
290 void AllocaScopeOp::getSuccessorRegions(
291  Optional<unsigned> index, ArrayRef<Attribute> operands,
292  SmallVectorImpl<RegionSuccessor> &regions) {
293  if (index) {
294  regions.push_back(RegionSuccessor(getResults()));
295  return;
296  }
297 
298  regions.push_back(RegionSuccessor(&getBodyRegion()));
299 }
300 
301 /// Given an operation, return whether this op is guaranteed to
302 /// allocate an AutomaticAllocationScopeResource
303 static bool isGuaranteedAutomaticAllocation(Operation *op) {
304  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
305  if (!interface)
306  return false;
307  for (auto res : op->getResults()) {
308  if (auto effect =
309  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
310  if (isa<SideEffects::AutomaticAllocationScopeResource>(
311  effect->getResource()))
312  return true;
313  }
314  }
315  return false;
316 }
317 
318 /// Given an operation, return whether this op itself could
319 /// allocate an AutomaticAllocationScopeResource. Note that
320 /// this will not check whether an operation contained within
321 /// the op can allocate.
322 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
323  // This op itself doesn't create a stack allocation,
324  // the inner allocation should be handled separately.
325  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
326  return false;
327  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
328  if (!interface)
329  return true;
330  for (auto res : op->getResults()) {
331  if (auto effect =
332  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
333  if (isa<SideEffects::AutomaticAllocationScopeResource>(
334  effect->getResource()))
335  return true;
336  }
337  }
338  return false;
339 }
340 
341 /// Return whether this op is the last non terminating op
342 /// in a region. That is to say, it is in a one-block region
343 /// and is only followed by a terminator. This prevents
344 /// extending the lifetime of allocations.
345 static bool lastNonTerminatorInRegion(Operation *op) {
346  return op->getNextNode() == op->getBlock()->getTerminator() &&
347  op->getParentRegion()->getBlocks().size() == 1;
348 }
349 
350 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
351 /// or it contains no allocation.
352 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
353  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
354 
355  LogicalResult matchAndRewrite(AllocaScopeOp op,
356  PatternRewriter &rewriter) const override {
357  bool hasPotentialAlloca =
358  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
359  if (alloc == op)
360  return WalkResult::advance();
361  if (isOpItselfPotentialAutomaticAllocation(alloc))
362  return WalkResult::interrupt();
363  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
364  return WalkResult::skip();
365  return WalkResult::advance();
366  }).wasInterrupted();
367 
368  // If this contains no potential allocation, it is always legal to
369  // inline. Otherwise, consider two conditions:
370  if (hasPotentialAlloca) {
371  // If the parent isn't an allocation scope, or we are not the last
372  // non-terminator op in the parent, we will extend the lifetime.
373  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
374  return failure();
375  if (!lastNonTerminatorInRegion(op))
376  return failure();
377  }
378 
379  Block *block = &op.getRegion().front();
380  Operation *terminator = block->getTerminator();
381  ValueRange results = terminator->getOperands();
382  rewriter.mergeBlockBefore(block, op);
383  rewriter.replaceOp(op, results);
384  rewriter.eraseOp(terminator);
385  return success();
386  }
387 };
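
// Illustrative effect of AllocaScopeInliner (assumed IR): when the parent op
// already has the AutomaticAllocationScope trait and the scope is the last
// non-terminator op, e.g.
//
//   func.func @f() {
//     memref.alloca_scope {
//       %a = memref.alloca() : memref<4xf32>
//       "test.use"(%a) : (memref<4xf32>) -> ()
//     }
//     return
//   }
//
// the scope body is merged into the function body, since the alloca's lifetime
// already ends with the enclosing function.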
388 
389 /// Move allocations into an allocation scope, if it is legal to
390 /// move them (e.g. their operands are available at the location
391 /// the op would be moved to).
392 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
393  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
394 
395  LogicalResult matchAndRewrite(AllocaScopeOp op,
396  PatternRewriter &rewriter) const override {
397 
398  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
399  return failure();
400 
401  Operation *lastParentWithoutScope = op->getParentOp();
402 
403  if (!lastParentWithoutScope ||
404  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
405  return failure();
406 
407  // Only apply if this is the last non-terminator
408  // op in the block (lest lifetime be extended) of a
409  // single-block region.
410  if (!lastNonTerminatorInRegion(op) ||
411  !lastNonTerminatorInRegion(lastParentWithoutScope))
412  return failure();
413 
414  while (!lastParentWithoutScope->getParentOp()
415  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
416  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
417  if (!lastParentWithoutScope ||
418  !lastNonTerminatorInRegion(lastParentWithoutScope))
419  return failure();
420  }
421  assert(lastParentWithoutScope->getParentOp()
422  ->hasTrait<OpTrait::AutomaticAllocationScope>());
423 
424  Region *containingRegion = nullptr;
425  for (auto &r : lastParentWithoutScope->getRegions()) {
426  if (r.isAncestor(op->getParentRegion())) {
427  assert(containingRegion == nullptr &&
428  "only one region can contain the op");
429  containingRegion = &r;
430  }
431  }
432  assert(containingRegion && "op must be contained in a region");
433 
434  SmallVector<Operation *> toHoist;
435  op->walk([&](Operation *alloc) {
436  if (!isGuaranteedAutomaticAllocation(alloc))
437  return WalkResult::skip();
438 
439  // If any operand is not defined before the location of
440  // lastParentWithoutScope (i.e. where we would hoist to), skip.
441  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
442  return containingRegion->isAncestor(v.getParentRegion());
443  }))
444  return WalkResult::skip();
445  toHoist.push_back(alloc);
446  return WalkResult::advance();
447  });
448 
449  if (toHoist.empty())
450  return failure();
451  rewriter.setInsertionPoint(lastParentWithoutScope);
452  for (auto *op : toHoist) {
453  auto *cloned = rewriter.clone(*op);
454  rewriter.replaceOp(op, cloned->getResults());
455  }
456  return success();
457  }
458 };
459 
460 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
461  MLIRContext *context) {
462  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // AssumeAlignmentOp
467 //===----------------------------------------------------------------------===//
468 
469 LogicalResult AssumeAlignmentOp::verify() {
470  if (!llvm::isPowerOf2_32(getAlignment()))
471  return emitOpError("alignment must be power of 2");
472  return success();
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // CastOp
477 //===----------------------------------------------------------------------===//
478 
479 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
480 /// source memref. This is useful to fold a memref.cast into a consuming op
481 /// and implement canonicalization patterns for ops in different dialects that
482 /// may consume the results of memref.cast operations. Such foldable memref.cast
483 /// operations are typically inserted as `view` and `subview` ops are
484 /// canonicalized, to preserve the type compatibility of their uses.
485 ///
486 /// Returns true when all conditions are met:
487 /// 1. source and result are ranked memrefs with strided semantics and same
488 /// element type and rank.
489 /// 2. each of the source's size, offset or stride has more static information
490 /// than the corresponding result's size, offset or stride.
491 ///
492 /// Example 1:
493 /// ```mlir
494 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
495 /// %2 = consumer %1 ... : memref<?x?xf32> ...
496 /// ```
497 ///
498 /// may fold into:
499 ///
500 /// ```mlir
501 /// %2 = consumer %0 ... : memref<8x16xf32> ...
502 /// ```
503 ///
504 /// Example 2:
505 /// ```
506 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
507 /// to memref<?x?xf32>
508 /// consumer %1 : memref<?x?xf32> ...
509 /// ```
510 ///
511 /// may fold into:
512 ///
513 /// ```
514 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
515 /// ```
516 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
517  MemRefType sourceType = castOp.getSource().getType().dyn_cast<MemRefType>();
518  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
519 
520  // Requires ranked MemRefType.
521  if (!sourceType || !resultType)
522  return false;
523 
524  // Requires same elemental type.
525  if (sourceType.getElementType() != resultType.getElementType())
526  return false;
527 
528  // Requires same rank.
529  if (sourceType.getRank() != resultType.getRank())
530  return false;
531 
532  // Only fold casts between strided memref forms.
533  int64_t sourceOffset, resultOffset;
534  SmallVector<int64_t, 4> sourceStrides, resultStrides;
535  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
536  failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
537  return false;
538 
539  // If cast is towards more static sizes along any dimension, don't fold.
540  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
541  auto ss = std::get<0>(it), st = std::get<1>(it);
542  if (ss != st)
543  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
544  return false;
545  }
546 
547  // If cast is towards more static offset along any dimension, don't fold.
548  if (sourceOffset != resultOffset)
549  if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
550  !ShapedType::isDynamicStrideOrOffset(resultOffset))
551  return false;
552 
553  // If cast is towards more static strides along any dimension, don't fold.
554  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
555  auto ss = std::get<0>(it), st = std::get<1>(it);
556  if (ss != st)
557  if (ShapedType::isDynamicStrideOrOffset(ss) &&
558  !ShapedType::isDynamicStrideOrOffset(st))
559  return false;
560  }
561 
562  return true;
563 }
564 
565 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
566  if (inputs.size() != 1 || outputs.size() != 1)
567  return false;
568  Type a = inputs.front(), b = outputs.front();
569  auto aT = a.dyn_cast<MemRefType>();
570  auto bT = b.dyn_cast<MemRefType>();
571 
572  auto uaT = a.dyn_cast<UnrankedMemRefType>();
573  auto ubT = b.dyn_cast<UnrankedMemRefType>();
574 
575  if (aT && bT) {
576  if (aT.getElementType() != bT.getElementType())
577  return false;
578  if (aT.getLayout() != bT.getLayout()) {
579  int64_t aOffset, bOffset;
580  SmallVector<int64_t, 4> aStrides, bStrides;
581  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
582  failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
583  aStrides.size() != bStrides.size())
584  return false;
585 
586  // Strides along a dimension/offset are compatible if the value in the
587  // source memref is static and the value in the target memref is the
588  // same. They are also compatible if either one is dynamic (see
589  // description of MemRefCastOp for details).
590  auto checkCompatible = [](int64_t a, int64_t b) {
591  return (a == MemRefType::getDynamicStrideOrOffset() ||
592  b == MemRefType::getDynamicStrideOrOffset() || a == b);
593  };
594  if (!checkCompatible(aOffset, bOffset))
595  return false;
596  for (const auto &aStride : enumerate(aStrides))
597  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
598  return false;
599  }
600  if (aT.getMemorySpace() != bT.getMemorySpace())
601  return false;
602 
603  // They must have the same rank, and any specified dimensions must match.
604  if (aT.getRank() != bT.getRank())
605  return false;
606 
607  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
608  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
609  if (aDim != -1 && bDim != -1 && aDim != bDim)
610  return false;
611  }
612  return true;
613  } else {
614  if (!aT && !uaT)
615  return false;
616  if (!bT && !ubT)
617  return false;
618  // Unranked to unranked casting is unsupported
619  if (uaT && ubT)
620  return false;
621 
622  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
623  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
624  if (aEltType != bEltType)
625  return false;
626 
627  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
628  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
629  return aMemSpace == bMemSpace;
630  }
631 
632  return false;
633 }
634 
635 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
636  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // CopyOp
641 //===----------------------------------------------------------------------===//
642 
643 namespace {
644 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
645 /// and element type, the cast can be skipped. Such CastOps only cast the layout
646 /// of the type.
647 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
648  using OpRewritePattern<CopyOp>::OpRewritePattern;
649 
650  LogicalResult matchAndRewrite(CopyOp copyOp,
651  PatternRewriter &rewriter) const override {
652  bool modified = false;
653 
654  // Check source.
655  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
656  auto fromType = castOp.getSource().getType().dyn_cast<MemRefType>();
657  auto toType = copyOp.getSource().getType().dyn_cast<MemRefType>();
658 
659  if (fromType && toType) {
660  if (fromType.getShape() == toType.getShape() &&
661  fromType.getElementType() == toType.getElementType()) {
662  rewriter.updateRootInPlace(copyOp, [&] {
663  copyOp.getSourceMutable().assign(castOp.getSource());
664  });
665  modified = true;
666  }
667  }
668  }
669 
670  // Check target.
671  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
672  auto fromType = castOp.getSource().getType().dyn_cast<MemRefType>();
673  auto toType = copyOp.getTarget().getType().dyn_cast<MemRefType>();
674 
675  if (fromType && toType) {
676  if (fromType.getShape() == toType.getShape() &&
677  fromType.getElementType() == toType.getElementType()) {
678  rewriter.updateRootInPlace(copyOp, [&] {
679  copyOp.getTargetMutable().assign(castOp.getSource());
680  });
681  modified = true;
682  }
683  }
684  }
685 
686  return success(modified);
687  }
688 };
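
// For instance (hypothetical IR), a layout-only cast feeding a copy is skipped:
//
//   %0 = memref.cast %src : memref<4xf32> to memref<4xf32, strided<[1], offset: ?>>
//   memref.copy %0, %dst : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32>
//
// becomes memref.copy %src, %dst, since shape and element type are unchanged.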
689 
690 /// Fold memref.copy(%x, %x).
691 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
692  using OpRewritePattern<CopyOp>::OpRewritePattern;
693 
694  LogicalResult matchAndRewrite(CopyOp copyOp,
695  PatternRewriter &rewriter) const override {
696  if (copyOp.getSource() != copyOp.getTarget())
697  return failure();
698 
699  rewriter.eraseOp(copyOp);
700  return success();
701  }
702 };
703 } // namespace
704 
705 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
706  MLIRContext *context) {
707  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
708 }
709 
710 LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
711  SmallVectorImpl<OpFoldResult> &results) {
712  /// copy(memrefcast) -> copy
713  bool folded = false;
714  Operation *op = *this;
715  for (OpOperand &operand : op->getOpOperands()) {
716  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
717  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
718  operand.set(castOp.getOperand());
719  folded = true;
720  }
721  }
722  return success(folded);
723 }
724 
725 //===----------------------------------------------------------------------===//
726 // DeallocOp
727 //===----------------------------------------------------------------------===//
728 
729 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
730  SmallVectorImpl<OpFoldResult> &results) {
731  /// dealloc(memrefcast) -> dealloc
732  return foldMemRefCast(*this);
733 }
734 
735 //===----------------------------------------------------------------------===//
736 // DimOp
737 //===----------------------------------------------------------------------===//
738 
739 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
740  int64_t index) {
741  auto loc = result.location;
742  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
743  build(builder, result, source, indexValue);
744 }
745 
746 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
747  Value index) {
748  auto indexTy = builder.getIndexType();
749  build(builder, result, indexTy, source, index);
750 }
751 
752 Optional<int64_t> DimOp::getConstantIndex() {
753  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
754  return constantOp.getValue().cast<IntegerAttr>().getInt();
755  return {};
756 }
757 
758 LogicalResult DimOp::verify() {
759  // Assume unknown index to be in range.
760  Optional<int64_t> index = getConstantIndex();
761  if (!index)
762  return success();
763 
764  // Check that constant index is not knowingly out of range.
765  auto type = getSource().getType();
766  if (auto memrefType = type.dyn_cast<MemRefType>()) {
767  if (*index >= memrefType.getRank())
768  return emitOpError("index is out of range");
769  } else if (type.isa<UnrankedMemRefType>()) {
770  // Assume index to be in range.
771  } else {
772  llvm_unreachable("expected operand with memref type");
773  }
774  return success();
775 }
776 
777 /// Return a map with key being elements in `vals` and data being number of
778 /// occurrences of it. Use std::map, since the `vals` here are strides and the
779 /// dynamic stride value is the same as the tombstone value for
780 /// `DenseMap<int64_t>`.
781 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
782  std::map<int64_t, unsigned> numOccurences;
783  for (auto val : vals)
784  numOccurences[val]++;
785  return numOccurences;
786 }
787 
788 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
789 /// to be a subset of `originalType` with some `1` entries erased, return the
790 /// set of indices that specifies which of the entries of `originalShape` are
791 /// dropped to obtain `reducedShape`.
792 /// This accounts for cases where there are multiple unit-dims, but only a
793 /// subset of those are dropped. For MemRefTypes these can be disambiguated
794 /// using the strides. If a dimension is dropped the stride must be dropped too.
795 static llvm::Optional<llvm::SmallBitVector>
796 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
797  ArrayRef<OpFoldResult> sizes) {
798  llvm::SmallBitVector unusedDims(originalType.getRank());
799  if (originalType.getRank() == reducedType.getRank())
800  return unusedDims;
801 
802  for (const auto &dim : llvm::enumerate(sizes))
803  if (auto attr = dim.value().dyn_cast<Attribute>())
804  if (attr.cast<IntegerAttr>().getInt() == 1)
805  unusedDims.set(dim.index());
806 
807  // Early exit for the case where the number of unused dims matches the number
808  // of ranks reduced.
809  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
810  originalType.getRank())
811  return unusedDims;
812 
813  SmallVector<int64_t> originalStrides, candidateStrides;
814  int64_t originalOffset, candidateOffset;
815  if (failed(
816  getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
817  failed(
818  getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
819  return llvm::None;
820 
821  // For memrefs, a dimension is truly dropped if its corresponding stride is
822  // also dropped. This is particularly important when more than one of the dims
823  // is 1. Track the number of occurences of the strides in the original type
824  // and the candidate type. For each unused dim that stride should not be
825  // present in the candidate type. Note that there could be multiple dimensions
826  // that have the same size. We don't need to exactly figure out which dim
827  // corresponds to which stride, we just need to verify that the number of
828  // repetitions of a stride in the original + number of unused dims with that
829  // stride == number of repetitions of a stride in the candidate.
830  std::map<int64_t, unsigned> currUnaccountedStrides =
831  getNumOccurences(originalStrides);
832  std::map<int64_t, unsigned> candidateStridesNumOccurences =
833  getNumOccurences(candidateStrides);
834  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
835  if (!unusedDims.test(dim))
836  continue;
837  int64_t originalStride = originalStrides[dim];
838  if (currUnaccountedStrides[originalStride] >
839  candidateStridesNumOccurences[originalStride]) {
840  // This dim can be treated as dropped.
841  currUnaccountedStrides[originalStride]--;
842  continue;
843  }
844  if (currUnaccountedStrides[originalStride] ==
845  candidateStridesNumOccurences[originalStride]) {
846  // The stride for this is not dropped. Keep as is.
847  unusedDims.reset(dim);
848  continue;
849  }
850  if (currUnaccountedStrides[originalStride] <
851  candidateStridesNumOccurences[originalStride]) {
852  // This should never happen. Can't have a stride in the reduced rank type
853  // that wasn't in the original one.
854  return llvm::None;
855  }
856  }
857 
858  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
859  originalType.getRank())
860  return llvm::None;
861  return unusedDims;
862 }
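
// Worked example (illustrative): reducing memref<1x4x1xf32, strided<[4, 1, 1]>>
// to memref<4x1xf32, strided<[1, 1]>> marks both unit dims as candidates, but
// only dim 0 is reported as dropped: its stride (4) no longer occurs among the
// candidate strides, while one stride-1 unit dim survives in the result.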
863 
864 llvm::SmallBitVector SubViewOp::getDroppedDims() {
865  MemRefType sourceType = getSourceType();
866  MemRefType resultType = getType();
867  llvm::Optional<llvm::SmallBitVector> unusedDims =
868  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
869  assert(unusedDims && "unable to find unused dims of subview");
870  return *unusedDims;
871 }
872 
873 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
874  // All forms of folding require a known index.
875  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
876  if (!index)
877  return {};
878 
879  // Folding for unranked types (UnrankedMemRefType) is not supported.
880  auto memrefType = getSource().getType().dyn_cast<MemRefType>();
881  if (!memrefType)
882  return {};
883 
884  // Fold if the shape extent along the given index is known.
885  if (!memrefType.isDynamicDim(index.getInt())) {
886  Builder builder(getContext());
887  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
888  }
889 
890  // The size at the given index is now known to be a dynamic size.
891  unsigned unsignedIndex = index.getValue().getZExtValue();
892 
893  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
894  Operation *definingOp = getSource().getDefiningOp();
895 
896  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
897  return *(alloc.getDynamicSizes().begin() +
898  memrefType.getDynamicDimIndex(unsignedIndex));
899 
900  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
901  return *(alloca.getDynamicSizes().begin() +
902  memrefType.getDynamicDimIndex(unsignedIndex));
903 
904  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
905  return *(view.getDynamicSizes().begin() +
906  memrefType.getDynamicDimIndex(unsignedIndex));
907 
908  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
909  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
910  unsigned resultIndex = 0;
911  unsigned sourceRank = subview.getSourceType().getRank();
912  unsigned sourceIndex = 0;
913  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
914  if (unusedDims.test(i))
915  continue;
916  if (resultIndex == unsignedIndex) {
917  sourceIndex = i;
918  break;
919  }
920  resultIndex++;
921  }
922  assert(subview.isDynamicSize(sourceIndex) &&
923  "expected dynamic subview size");
924  return subview.getDynamicSize(sourceIndex);
925  }
926 
927  if (auto sizeInterface =
928  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
929  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
930  "Expected dynamic subview size");
931  return sizeInterface.getDynamicSize(unsignedIndex);
932  }
933 
934  // dim(memrefcast) -> dim
935  if (succeeded(foldMemRefCast(*this)))
936  return getResult();
937 
938  return {};
939 }
940 
941 namespace {
942 /// Fold dim of a memref reshape operation to a load into the reshape's shape
943 /// operand.
944 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
945  using OpRewritePattern<DimOp>::OpRewritePattern;
946 
947  LogicalResult matchAndRewrite(DimOp dim,
948  PatternRewriter &rewriter) const override {
949  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
950 
951  if (!reshape)
952  return failure();
953 
954  // Place the load directly after the reshape to ensure that the shape memref
955  // was not mutated.
956  rewriter.setInsertionPointAfter(reshape);
957  Location loc = dim.getLoc();
958  Value load =
959  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
960  if (load.getType() != dim.getType())
961  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
962  rewriter.replaceOp(dim, load);
963  return success();
964  }
965 };
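
// For example (hypothetical IR), DimOfMemRefReshape rewrites
//
//   %r = memref.reshape %src(%shape)
//       : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %c1 : memref<?x?xf32>
//
// into a load of the shape operand: %d = memref.load %shape[%c1].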
966 
967 } // namespace
968 
969 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
970  MLIRContext *context) {
971  results.add<DimOfMemRefReshape>(context);
972 }
973 
974 // ---------------------------------------------------------------------------
975 // DmaStartOp
976 // ---------------------------------------------------------------------------
977 
978 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
979  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
980  ValueRange destIndices, Value numElements,
981  Value tagMemRef, ValueRange tagIndices, Value stride,
982  Value elementsPerStride) {
983  result.addOperands(srcMemRef);
984  result.addOperands(srcIndices);
985  result.addOperands(destMemRef);
986  result.addOperands(destIndices);
987  result.addOperands({numElements, tagMemRef});
988  result.addOperands(tagIndices);
989  if (stride)
990  result.addOperands({stride, elementsPerStride});
991 }
992 
993 void DmaStartOp::print(OpAsmPrinter &p) {
994  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
995  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
996  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
997  if (isStrided())
998  p << ", " << getStride() << ", " << getNumElementsPerStride();
999 
1000  p.printOptionalAttrDict((*this)->getAttrs());
1001  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1002  << ", " << getTagMemRef().getType();
1003 }
1004 
1005 // Parse DmaStartOp.
1006 // Ex:
1007 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1008 // %tag[%index], %stride, %num_elt_per_stride :
1009 // : memref<3076 x f32, 0>,
1010 // memref<1024 x f32, 2>,
1011 // memref<1 x i32>
1012 //
1013 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1014  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1015  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1016  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1017  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1018  OpAsmParser::UnresolvedOperand numElementsInfo;
1019  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1020  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1021  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1022 
1023  SmallVector<Type, 3> types;
1024  auto indexType = parser.getBuilder().getIndexType();
1025 
1026  // Parse and resolve the following list of operands:
1027  // *) source memref followed by its indices (in square brackets).
1028  // *) destination memref followed by its indices (in square brackets).
1029  // *) number of elements to transfer.
1030  if (parser.parseOperand(srcMemRefInfo) ||
1031  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1032  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1033  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1034  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1035  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1036  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1037  return failure();
1038 
1039  // Parse optional stride and elements per stride.
1040  if (parser.parseTrailingOperandList(strideInfo))
1041  return failure();
1042 
1043  bool isStrided = strideInfo.size() == 2;
1044  if (!strideInfo.empty() && !isStrided) {
1045  return parser.emitError(parser.getNameLoc(),
1046  "expected two stride related operands");
1047  }
1048 
1049  if (parser.parseColonTypeList(types))
1050  return failure();
1051  if (types.size() != 3)
1052  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1053 
1054  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1055  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1056  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1057  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1058  // size should be an index.
1059  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1060  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1061  // tag indices should be index.
1062  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1063  return failure();
1064 
1065  if (isStrided) {
1066  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1067  return failure();
1068  }
1069 
1070  return success();
1071 }
1072 
1073 LogicalResult DmaStartOp::verify() {
1074  unsigned numOperands = getNumOperands();
1075 
1076  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1077  // the number of elements.
1078  if (numOperands < 4)
1079  return emitOpError("expected at least 4 operands");
1080 
1081  // Check types of operands. The order of these calls is important: the later
1082  // calls rely on some type properties to compute the operand position.
1083  // 1. Source memref.
1084  if (!getSrcMemRef().getType().isa<MemRefType>())
1085  return emitOpError("expected source to be of memref type");
1086  if (numOperands < getSrcMemRefRank() + 4)
1087  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1088  << " operands";
1089  if (!getSrcIndices().empty() &&
1090  !llvm::all_of(getSrcIndices().getTypes(),
1091  [](Type t) { return t.isIndex(); }))
1092  return emitOpError("expected source indices to be of index type");
1093 
1094  // 2. Destination memref.
1095  if (!getDstMemRef().getType().isa<MemRefType>())
1096  return emitOpError("expected destination to be of memref type");
1097  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1098  if (numOperands < numExpectedOperands)
1099  return emitOpError() << "expected at least " << numExpectedOperands
1100  << " operands";
1101  if (!getDstIndices().empty() &&
1102  !llvm::all_of(getDstIndices().getTypes(),
1103  [](Type t) { return t.isIndex(); }))
1104  return emitOpError("expected destination indices to be of index type");
1105 
1106  // 3. Number of elements.
1107  if (!getNumElements().getType().isIndex())
1108  return emitOpError("expected num elements to be of index type");
1109 
1110  // 4. Tag memref.
1111  if (!getTagMemRef().getType().isa<MemRefType>())
1112  return emitOpError("expected tag to be of memref type");
1113  numExpectedOperands += getTagMemRefRank();
1114  if (numOperands < numExpectedOperands)
1115  return emitOpError() << "expected at least " << numExpectedOperands
1116  << " operands";
1117  if (!getTagIndices().empty() &&
1118  !llvm::all_of(getTagIndices().getTypes(),
1119  [](Type t) { return t.isIndex(); }))
1120  return emitOpError("expected tag indices to be of index type");
1121 
1122  // Optional stride-related operands must be either both present or both
1123  // absent.
1124  if (numOperands != numExpectedOperands &&
1125  numOperands != numExpectedOperands + 2)
1126  return emitOpError("incorrect number of operands");
1127 
1128  // 5. Strides.
1129  if (isStrided()) {
1130  if (!getStride().getType().isIndex() ||
1131  !getNumElementsPerStride().getType().isIndex())
1132  return emitOpError(
1133  "expected stride and num elements per stride to be of type index");
1134  }
1135 
1136  return success();
1137 }
1138 
1139 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1140  SmallVectorImpl<OpFoldResult> &results) {
1141  /// dma_start(memrefcast) -> dma_start
1142  return foldMemRefCast(*this);
1143 }
1144 
1145 // ---------------------------------------------------------------------------
1146 // DmaWaitOp
1147 // ---------------------------------------------------------------------------
1148 
1149 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1150  SmallVectorImpl<OpFoldResult> &results) {
1151  /// dma_wait(memrefcast) -> dma_wait
1152  return foldMemRefCast(*this);
1153 }
1154 
1155 LogicalResult DmaWaitOp::verify() {
1156  // Check that the number of tag indices matches the tagMemRef rank.
1157  unsigned numTagIndices = getTagIndices().size();
1158  unsigned tagMemRefRank = getTagMemRefRank();
1159  if (numTagIndices != tagMemRefRank)
1160  return emitOpError() << "expected tagIndices to have the same number of "
1161  "elements as the tagMemRef rank, expected "
1162  << tagMemRefRank << ", but got " << numTagIndices;
1163  return success();
1164 }
1165 
1166 //===----------------------------------------------------------------------===//
1167 // GenericAtomicRMWOp
1168 //===----------------------------------------------------------------------===//
1169 
1170 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1171  Value memref, ValueRange ivs) {
1172  result.addOperands(memref);
1173  result.addOperands(ivs);
1174 
1175  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
1176  Type elementType = memrefType.getElementType();
1177  result.addTypes(elementType);
1178 
1179  Region *bodyRegion = result.addRegion();
1180  bodyRegion->push_back(new Block());
1181  bodyRegion->addArgument(elementType, memref.getLoc());
1182  }
1183 }
1184 
1185 LogicalResult GenericAtomicRMWOp::verify() {
1186  auto &body = getRegion();
1187  if (body.getNumArguments() != 1)
1188  return emitOpError("expected single number of entry block arguments");
1189 
1190  if (getResult().getType() != body.getArgument(0).getType())
1191  return emitOpError("expected block argument of the same type result type");
1192 
1193  bool hasSideEffects =
1194  body.walk([&](Operation *nestedOp) {
1195  if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
1196  return WalkResult::advance();
1197  nestedOp->emitError(
1198  "body of 'memref.generic_atomic_rmw' should contain "
1199  "only operations with no side effects");
1200  return WalkResult::interrupt();
1201  })
1202  .wasInterrupted();
1203  return hasSideEffects ? failure() : success();
1204 }
1205 
1206 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1207  OperationState &result) {
1208  OpAsmParser::UnresolvedOperand memref;
1209  Type memrefType;
1210  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1211 
1212  Type indexType = parser.getBuilder().getIndexType();
1213  if (parser.parseOperand(memref) ||
1214  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1215  parser.parseColonType(memrefType) ||
1216  parser.resolveOperand(memref, memrefType, result.operands) ||
1217  parser.resolveOperands(ivs, indexType, result.operands))
1218  return failure();
1219 
1220  Region *body = result.addRegion();
1221  if (parser.parseRegion(*body, {}) ||
1222  parser.parseOptionalAttrDict(result.attributes))
1223  return failure();
1224  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
1225  return success();
1226 }
1227 
1228 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1229  p << ' ' << getMemref() << "[" << getIndices()
1230  << "] : " << getMemref().getType() << ' ';
1231  p.printRegion(getRegion());
1232  p.printOptionalAttrDict((*this)->getAttrs());
1233 }
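
// The custom form printed/parsed above looks like (illustrative):
//
//   %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
//   ^bb0(%current_value : f32):
//     %c1 = arith.constant 1.0 : f32
//     %inc = arith.addf %c1, %current_value : f32
//     memref.atomic_yield %inc : f32
//   }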
1234 
1235 //===----------------------------------------------------------------------===//
1236 // AtomicYieldOp
1237 //===----------------------------------------------------------------------===//
1238 
1239 LogicalResult AtomicYieldOp::verify() {
1240  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1241  Type resultType = getResult().getType();
1242  if (parentType != resultType)
1243  return emitOpError() << "types mismatch between yield op: " << resultType
1244  << " and its parent: " << parentType;
1245  return success();
1246 }
1247 
1248 //===----------------------------------------------------------------------===//
1249 // GlobalOp
1250 //===----------------------------------------------------------------------===//
1251 
1252 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1253  TypeAttr type,
1254  Attribute initialValue) {
1255  p << type;
1256  if (!op.isExternal()) {
1257  p << " = ";
1258  if (op.isUninitialized())
1259  p << "uninitialized";
1260  else
1261  p.printAttributeWithoutType(initialValue);
1262  }
1263 }
1264 
1265 static ParseResult
1266 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1267  Attribute &initialValue) {
1268  Type type;
1269  if (parser.parseType(type))
1270  return failure();
1271 
1272  auto memrefType = type.dyn_cast<MemRefType>();
1273  if (!memrefType || !memrefType.hasStaticShape())
1274  return parser.emitError(parser.getNameLoc())
1275  << "type should be static shaped memref, but got " << type;
1276  typeAttr = TypeAttr::get(type);
1277 
1278  if (parser.parseOptionalEqual())
1279  return success();
1280 
1281  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1282  initialValue = UnitAttr::get(parser.getContext());
1283  return success();
1284  }
1285 
1286  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1287  if (parser.parseAttribute(initialValue, tensorType))
1288  return failure();
1289  if (!initialValue.isa<ElementsAttr>())
1290  return parser.emitError(parser.getNameLoc())
1291  << "initial value should be a unit or elements attribute";
1292  return success();
1293 }
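
// Together with the declarative assembly format, the helpers above handle the
// trailing "type = initial-value" portion of forms such as (illustrative):
//
//   memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @uninit : memref<4xi32> = uninitialized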
1294 
1295 LogicalResult GlobalOp::verify() {
1296  auto memrefType = getType().dyn_cast<MemRefType>();
1297  if (!memrefType || !memrefType.hasStaticShape())
1298  return emitOpError("type should be static shaped memref, but got ")
1299  << getType();
1300 
1301  // Verify that the initial value, if present, is either a unit attribute or
1302  // an elements attribute.
1303  if (getInitialValue().has_value()) {
1304  Attribute initValue = getInitialValue().value();
1305  if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1306  return emitOpError("initial value should be a unit or elements "
1307  "attribute, but got ")
1308  << initValue;
1309 
1310  // Check that the type of the initial value is compatible with the type of
1311  // the global variable.
1312  if (auto elementsAttr = initValue.dyn_cast<ElementsAttr>()) {
1313  Type initType = elementsAttr.getType();
1314  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1315  if (initType != tensorType)
1316  return emitOpError("initial value expected to be of type ")
1317  << tensorType << ", but was of type " << initType;
1318  }
1319  }
1320 
1321  if (Optional<uint64_t> alignAttr = getAlignment()) {
1322  uint64_t alignment = *alignAttr;
1323 
1324  if (!llvm::isPowerOf2_64(alignment))
1325  return emitError() << "alignment attribute value " << alignment
1326  << " is not a power of 2";
1327  }
1328 
1329  // TODO: verify visibility for declarations.
1330  return success();
1331 }
1332 
1333 ElementsAttr GlobalOp::getConstantInitValue() {
1334  auto initVal = getInitialValue();
1335  if (getConstant() && initVal.has_value())
1336  return initVal.value().cast<ElementsAttr>();
1337  return {};
1338 }
1339 
1340 //===----------------------------------------------------------------------===//
1341 // GetGlobalOp
1342 //===----------------------------------------------------------------------===//
1343 
1344 LogicalResult
1345 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1346  // Verify that the result type is same as the type of the referenced
1347  // memref.global op.
1348  auto global =
1349  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1350  if (!global)
1351  return emitOpError("'")
1352  << getName() << "' does not reference a valid global memref";
1353 
1354  Type resultType = getResult().getType();
1355  if (global.getType() != resultType)
1356  return emitOpError("result type ")
1357  << resultType << " does not match type " << global.getType()
1358  << " of the global memref @" << getName();
1359  return success();
1360 }
1361 
1362 //===----------------------------------------------------------------------===//
1363 // LoadOp
1364 //===----------------------------------------------------------------------===//
1365 
1366 LogicalResult LoadOp::verify() {
1367  if (getNumOperands() != 1 + getMemRefType().getRank())
1368  return emitOpError("incorrect number of indices for load");
1369  return success();
1370 }
1371 
1372 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1373  /// load(memrefcast) -> load
1374  if (succeeded(foldMemRefCast(*this)))
1375  return getResult();
1376  return OpFoldResult();
1377 }
1378 
1379 //===----------------------------------------------------------------------===//
1380 // PrefetchOp
1381 //===----------------------------------------------------------------------===//
1382 
1383 void PrefetchOp::print(OpAsmPrinter &p) {
1384  p << " " << getMemref() << '[';
1385  p.printOperands(getIndices());
1386  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1387  p << ", locality<" << getLocalityHint();
1388  p << ">, " << (getIsDataCache() ? "data" : "instr");
1389  p.printOptionalAttrDict(
1390  (*this)->getAttrs(),
1391  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1392  p << " : " << getMemRefType();
1393 }
1394 
1395 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1396  OpAsmParser::UnresolvedOperand memrefInfo;
1397  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1398  IntegerAttr localityHint;
1399  MemRefType type;
1400  StringRef readOrWrite, cacheType;
1401 
1402  auto indexTy = parser.getBuilder().getIndexType();
1403  auto i32Type = parser.getBuilder().getIntegerType(32);
1404  if (parser.parseOperand(memrefInfo) ||
1405  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1406  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1407  parser.parseComma() || parser.parseKeyword("locality") ||
1408  parser.parseLess() ||
1409  parser.parseAttribute(localityHint, i32Type, "localityHint",
1410  result.attributes) ||
1411  parser.parseGreater() || parser.parseComma() ||
1412  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1413  parser.resolveOperand(memrefInfo, type, result.operands) ||
1414  parser.resolveOperands(indexInfo, indexTy, result.operands))
1415  return failure();
1416 
1417  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1418  return parser.emitError(parser.getNameLoc(),
1419  "rw specifier has to be 'read' or 'write'");
1420  result.addAttribute(
1421  PrefetchOp::getIsWriteAttrStrName(),
1422  parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1423 
1424  if (!cacheType.equals("data") && !cacheType.equals("instr"))
1425  return parser.emitError(parser.getNameLoc(),
1426  "cache type has to be 'data' or 'instr'");
1427 
1428  result.addAttribute(
1429  PrefetchOp::getIsDataCacheAttrStrName(),
1430  parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1431 
1432  return success();
1433 }
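
// The textual form handled by the print/parse methods above is, e.g.
// (illustrative):
//
//   memref.prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xi32>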
1434 
1435 LogicalResult PrefetchOp::verify() {
1436  if (getNumOperands() != 1 + getMemRefType().getRank())
1437  return emitOpError("too few indices");
1438 
1439  return success();
1440 }
1441 
1442 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1443  SmallVectorImpl<OpFoldResult> &results) {
1444  // prefetch(memrefcast) -> prefetch
1445  return foldMemRefCast(*this);
1446 }
1447 
1448 //===----------------------------------------------------------------------===//
1449 // RankOp
1450 //===----------------------------------------------------------------------===//
1451 
1452 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1453  // Constant fold rank when the rank of the operand is known.
1454  auto type = getOperand().getType();
1455  auto shapedType = type.dyn_cast<ShapedType>();
1456  if (shapedType && shapedType.hasRank())
1457  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1458  return IntegerAttr();
1459 }
1460 
1461 //===----------------------------------------------------------------------===//
1462 // ReinterpretCastOp
1463 //===----------------------------------------------------------------------===//
1464 
1465 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1466 /// `staticSizes` and `staticStrides` are automatically filled with
1467 /// source-memref-rank sentinel values that encode dynamic entries.
1468 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1469  MemRefType resultType, Value source,
1470  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1471  ArrayRef<OpFoldResult> strides,
1472  ArrayRef<NamedAttribute> attrs) {
1473  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1474  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1475  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1476  ShapedType::kDynamicStrideOrOffset);
1477  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1478  ShapedType::kDynamicSize);
1479  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1480  ShapedType::kDynamicStrideOrOffset);
1481  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1482  dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1483  b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1484  result.addAttributes(attrs);
1485 }
1486 
1487 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1488  MemRefType resultType, Value source,
1489  int64_t offset, ArrayRef<int64_t> sizes,
1490  ArrayRef<int64_t> strides,
1491  ArrayRef<NamedAttribute> attrs) {
1492  SmallVector<OpFoldResult> sizeValues =
1493  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1494  return b.getI64IntegerAttr(v);
1495  }));
1496  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1497  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1498  return b.getI64IntegerAttr(v);
1499  }));
1500  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1501  strideValues, attrs);
1502 }
1503 
1504 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1505  MemRefType resultType, Value source, Value offset,
1506  ValueRange sizes, ValueRange strides,
1507  ArrayRef<NamedAttribute> attrs) {
1508  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1509  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1510  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1511  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1512  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1513 }
1514 
1515 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1516 // completed automatically, like we have for subview and extract_slice.
1517 LogicalResult ReinterpretCastOp::verify() {
1518  // The source and result memrefs should be in the same memory space.
1519  auto srcType = getSource().getType().cast<BaseMemRefType>();
1520  auto resultType = getType().cast<MemRefType>();
1521  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1522  return emitError("different memory spaces specified for source type ")
1523  << srcType << " and result memref type " << resultType;
1524  if (srcType.getElementType() != resultType.getElementType())
1525  return emitError("different element types specified for source type ")
1526  << srcType << " and result memref type " << resultType;
1527 
1528  // Match sizes in result memref type and in static_sizes attribute.
1529  for (auto &en : llvm::enumerate(llvm::zip(
1530  resultType.getShape(), extractFromI64ArrayAttr(getStaticSizes())))) {
1531  int64_t resultSize = std::get<0>(en.value());
1532  int64_t expectedSize = std::get<1>(en.value());
1533  if (!ShapedType::isDynamic(resultSize) &&
1534  !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1535  return emitError("expected result type with size = ")
1536  << expectedSize << " instead of " << resultSize
1537  << " in dim = " << en.index();
1538  }
1539 
1540  // Match offset and strides in static_offsets and static_strides attributes. If
1541  // result memref type has no affine map specified, this will assume an
1542  // identity layout.
1543  int64_t resultOffset;
1544  SmallVector<int64_t, 4> resultStrides;
1545  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1546  return emitError("expected result type to have strided layout but found ")
1547  << resultType;
1548 
1549  // Match offset in result memref type and in static_offsets attribute.
1550  int64_t expectedOffset = extractFromI64ArrayAttr(getStaticOffsets()).front();
1551  if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
1552  !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
1553  resultOffset != expectedOffset)
1554  return emitError("expected result type with offset = ")
1555  << resultOffset << " instead of " << expectedOffset;
1556 
1557  // Match strides in result memref type and in static_strides attribute.
1558  for (auto &en : llvm::enumerate(llvm::zip(
1559  resultStrides, extractFromI64ArrayAttr(getStaticStrides())))) {
1560  int64_t resultStride = std::get<0>(en.value());
1561  int64_t expectedStride = std::get<1>(en.value());
1562  if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
1563  !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
1564  resultStride != expectedStride)
1565  return emitError("expected result type with stride = ")
1566  << expectedStride << " instead of " << resultStride
1567  << " in dim = " << en.index();
1568  }
1569 
1570  return success();
1571 }
1572 
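/// A minimal sketch of the fold below (illustrative IR, assuming an f32
/// element type):
///
/// ```
///   %c = memref.cast %m : memref<8x8xf32> to memref<?x?xf32>
///   %r = memref.reinterpret_cast %c to offset: [0], sizes: [64], strides: [1]
///       : memref<?x?xf32> to memref<64xf32>
/// ```
/// folds into a reinterpret_cast that takes %m directly as its source.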
1573 OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
1574  Value src = getSource();
1575  auto getPrevSrc = [&]() -> Value {
1576  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1577  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1578  return prev.getSource();
1579 
1580  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1581  if (auto prev = src.getDefiningOp<CastOp>())
1582  return prev.getSource();
1583 
1584  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1585  // are 0.
1586  if (auto prev = src.getDefiningOp<SubViewOp>())
1587  if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1588  return isConstantIntValue(val, 0);
1589  }))
1590  return prev.getSource();
1591 
1592  return nullptr;
1593  };
1594 
1595  if (auto prevSrc = getPrevSrc()) {
1596  getSourceMutable().assign(prevSrc);
1597  return getResult();
1598  }
1599 
1600  return nullptr;
1601 }
1602 
1603 //===----------------------------------------------------------------------===//
1604 // Reassociative reshape ops
1605 //===----------------------------------------------------------------------===//
1606 
1607 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
1608 /// result and operand. Layout maps are verified separately.
1609 ///
1610 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
1611 /// allowed in a reassociation group.
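///
/// Illustrative example (not from the original comment): collapsing
/// memref<2x3x4xf32> with reassociation [[0, 1], [2]] must produce a
/// memref<6x4xf32>; a collapsed shape of 5x4 would be rejected because the
/// group size 2 * 3 = 6 does not match 5.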
1612 static LogicalResult
1613 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
1614  ArrayRef<int64_t> expandedShape,
1615  ArrayRef<ReassociationIndices> reassociation,
1616  bool allowMultipleDynamicDimsPerGroup) {
1617  // There must be one reassociation group per collapsed dimension.
1618  if (collapsedShape.size() != reassociation.size())
1619  return op->emitOpError("invalid number of reassociation groups: found ")
1620  << reassociation.size() << ", expected " << collapsedShape.size();
1621 
1622  // The next expected expanded dimension index (while iterating over
1623  // reassociation indices).
1624  int64_t nextDim = 0;
1625  for (const auto &it : llvm::enumerate(reassociation)) {
1626  ReassociationIndices group = it.value();
1627  int64_t collapsedDim = it.index();
1628 
1629  bool foundDynamic = false;
1630  for (int64_t expandedDim : group) {
1631  if (expandedDim != nextDim++)
1632  return op->emitOpError("reassociation indices must be contiguous");
1633 
1634  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
1635  return op->emitOpError("reassociation index ")
1636  << expandedDim << " is out of bounds";
1637 
1638  // Check if there are multiple dynamic dims in a reassociation group.
1639  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
1640  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
1641  return op->emitOpError(
1642  "at most one dimension in a reassociation group may be dynamic");
1643  foundDynamic = true;
1644  }
1645  }
1646 
1647  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
1648  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
1649  return op->emitOpError("collapsed dim (")
1650  << collapsedDim
1651  << ") must be dynamic if and only if reassociation group is "
1652  "dynamic";
1653 
1654  // If all dims in the reassociation group are static, the size of the
1655  // collapsed dim can be verified.
1656  if (!foundDynamic) {
1657  int64_t groupSize = 1;
1658  for (int64_t expandedDim : group)
1659  groupSize *= expandedShape[expandedDim];
1660  if (groupSize != collapsedShape[collapsedDim])
1661  return op->emitOpError("collapsed dim size (")
1662  << collapsedShape[collapsedDim]
1663  << ") must equal reassociation group size (" << groupSize << ")";
1664  }
1665  }
1666 
1667  if (collapsedShape.empty()) {
1668  // Rank 0: All expanded dimensions must be 1.
1669  for (int64_t d : expandedShape)
1670  if (d != 1)
1671  return op->emitOpError(
1672  "rank 0 memrefs can only be extended/collapsed with/from ones");
1673  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
1674  // Rank >= 1: Number of dimensions among all reassociation groups must match
1675  // the result memref rank.
1676  return op->emitOpError("expanded rank (")
1677  << expandedShape.size()
1678  << ") inconsistent with number of reassociation indices (" << nextDim
1679  << ")";
1680  }
1681 
1682  return success();
1683 }
1684 
1685 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1686  return getSymbolLessAffineMaps(getReassociationExprs());
1687 }
1688 
1689 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1690  return convertReassociationIndicesToExprs(getContext(),
1691  getReassociationIndices());
1692 }
1693 
1694 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1695  return getSymbolLessAffineMaps(getReassociationExprs());
1696 }
1697 
1698 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1699  return convertReassociationIndicesToExprs(getContext(),
1700  getReassociationIndices());
1701 }
1702 
1703 /// Compute the layout map after expanding a given source MemRef type with the
1704 /// specified reassociation indices.
1705 static FailureOr<AffineMap>
1706 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
1707  ArrayRef<ReassociationIndices> reassociation) {
1708  int64_t srcOffset;
1709  SmallVector<int64_t> srcStrides;
1710  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
1711  return failure();
1712  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
1713 
1714  // 1-1 mapping between srcStrides and reassociation packs.
1715  // Each srcStride starts with the given value and gets expanded according to
1716  // the proper entries in resultShape.
1717  // Example:
1718  // srcStrides = [10000, 1 , 100 ],
1719  // reassociations = [ [0], [1], [2, 3, 4]],
1720  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
1721  // -> For the purpose of stride calculation, the useful sizes are:
1722  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
1723  // resultStrides = [10000, 1, 600, 200, 100]
1724  // Note that a stride does not get expanded along the first entry of each
1725  // shape pack.
1726  SmallVector<int64_t> reverseResultStrides;
1727  reverseResultStrides.reserve(resultShape.size());
1728  unsigned shapeIndex = resultShape.size() - 1;
1729  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
1730  ReassociationIndices reassoc = std::get<0>(it);
1731  int64_t currentStrideToExpand = std::get<1>(it);
1732  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
1733  using saturated_arith::Wrapper;
1734  reverseResultStrides.push_back(currentStrideToExpand);
1735  currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
1736  Wrapper::size(resultShape[shapeIndex--]))
1737  .asStride();
1738  }
1739  }
1740  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
1741  resultStrides.resize(resultShape.size(), 1);
1742  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
1743  srcType.getContext());
1744 }
1745 
1746 static FailureOr<MemRefType>
1747 computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape,
1748  ArrayRef<ReassociationIndices> reassociation) {
1749  if (srcType.getLayout().isIdentity()) {
1750  // If the source is contiguous (i.e., no layout map specified), so is the
1751  // result.
1752  MemRefLayoutAttrInterface layout;
1753  return MemRefType::get(resultShape, srcType.getElementType(), layout,
1754  srcType.getMemorySpace());
1755  }
1756 
1757  // Source may not be contiguous. Compute the layout map.
1758  FailureOr<AffineMap> computedLayout =
1759  computeExpandedLayoutMap(srcType, resultShape, reassociation);
1760  if (failed(computedLayout))
1761  return failure();
1762  auto computedType =
1763  MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
1764  srcType.getMemorySpaceAsInt());
1765  return canonicalizeStridedLayout(computedType);
1766 }
1767 
1768 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1769  ArrayRef<int64_t> resultShape, Value src,
1770  ArrayRef<ReassociationIndices> reassociation) {
1771  // Only ranked memref source values are supported.
1772  auto srcType = src.getType().cast<MemRefType>();
1773  FailureOr<MemRefType> resultType =
1774  computeExpandedType(srcType, resultShape, reassociation);
1775  // Failure of this assertion usually indicates a problem with the source
1776  // type, e.g., could not get strides/offset.
1777  assert(succeeded(resultType) && "could not compute layout");
1778  build(builder, result, *resultType, src, reassociation);
1779 }
1780 
1781 LogicalResult ExpandShapeOp::verify() {
1782  MemRefType srcType = getSrcType();
1783  MemRefType resultType = getResultType();
1784 
1785  // Verify result shape.
1786  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
1787  resultType.getShape(),
1788  getReassociationIndices(),
1789  /*allowMultipleDynamicDimsPerGroup=*/false)))
1790  return failure();
1791 
1792  // Compute expected result type (including layout map).
1793  FailureOr<MemRefType> expectedResultType = computeExpandedType(
1794  srcType, resultType.getShape(), getReassociationIndices());
1795  if (failed(expectedResultType))
1796  return emitOpError("invalid source layout map");
1797 
1798  // Check actual result type.
1799  auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
1800  if (*expectedResultType != canonicalizedResultType)
1801  return emitOpError("expected expanded type to be ")
1802  << *expectedResultType << " but found " << canonicalizedResultType;
1803 
1804  return success();
1805 }
1806 
1807 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1808  MLIRContext *context) {
1809  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
1810  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
1811  context);
1812 }
1813 
1814 /// Compute the layout map after collapsing a given source MemRef type with the
1815 /// specified reassociation indices.
1816 ///
1817 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
1818 /// not possible to check this by inspecting a MemRefType in the general case.
1819 /// If non-contiguity cannot be checked statically, the collapse is assumed to
1820 /// be valid (and thus accepted by this function) unless `strict = true`.
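///
/// Worked example (hypothetical strides): collapsing group [0, 1] of a
/// memref<4x4xf32> with strides [4, 1] is accepted because
/// 1 (stride of dim 1) * 4 (size of dim 1) == 4 (stride of dim 0); with
/// strides [8, 1] the same group is rejected, since 1 * 4 != 8 means the two
/// dims are not contiguous in memory.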
1821 static FailureOr<AffineMap>
1822 computeCollapsedLayoutMap(MemRefType srcType,
1823  ArrayRef<ReassociationIndices> reassociation,
1824  bool strict = false) {
1825  int64_t srcOffset;
1826  SmallVector<int64_t> srcStrides;
1827  auto srcShape = srcType.getShape();
1828  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
1829  return failure();
1830 
1831  // The result stride of a reassociation group is the stride of the last entry
1832  // of the reassociation. (TODO: Should be the minimum stride in the
1833  // reassociation because strides are not necessarily sorted. E.g., when using
1834  // memref.transpose.) Dimensions of size 1 should be skipped, because their
1835  // strides are meaningless and could have any arbitrary value.
1836  SmallVector<int64_t> resultStrides;
1837  resultStrides.reserve(reassociation.size());
1838  for (const ReassociationIndices &reassoc : reassociation) {
1839  ArrayRef<int64_t> ref = llvm::makeArrayRef(reassoc);
1840  while (srcShape[ref.back()] == 1 && ref.size() > 1)
1841  ref = ref.drop_back();
1842  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1843  resultStrides.push_back(srcStrides[ref.back()]);
1844  } else {
1845  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
1846  // the corresponding stride may have to be skipped. (See above comment.)
1847  // Therefore, the result stride cannot be statically determined and must
1848  // be dynamic.
1849  resultStrides.push_back(ShapedType::kDynamicStrideOrOffset);
1850  }
1851  }
1852 
1853  // Validate that each reassociation group is contiguous.
1854  unsigned resultStrideIndex = resultStrides.size() - 1;
1855  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
1856  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
1857  using saturated_arith::Wrapper;
1858  auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
1859  for (int64_t idx : llvm::reverse(trailingReassocs)) {
1860  stride = stride * Wrapper::size(srcShape[idx]);
1861 
1862  // Both source and result stride must have the same static value. In that
1863  // case, we can be sure that the dimensions are collapsible (because they
1864  // are contiguous).
1865  //
1866  // One special case is when the srcShape is `1`, in which case it can
1867  // never produce non-contiguity.
1868  if (srcShape[idx] == 1)
1869  continue;
1870 
1871  // If `strict = false` (default during op verification), we accept cases
1872  // where one or both strides are dynamic. This is best effort: We reject
1873  // ops where obviously non-contiguous dims are collapsed, but accept ops
1874  // where we cannot be sure statically. Such ops may fail at runtime. See
1875  // the op documentation for details.
1876  auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
1877  if (strict && (stride.saturated || srcStride.saturated))
1878  return failure();
1879 
1880  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
1881  return failure();
1882  }
1883  }
1884  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
1885  srcType.getContext());
1886 }
1887 
1888 bool CollapseShapeOp::isGuaranteedCollapsible(
1889  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
1890  // MemRefs with standard layout are always collapsible.
1891  if (srcType.getLayout().isIdentity())
1892  return true;
1893 
1894  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
1895  /*strict=*/true));
1896 }
1897 
1898 static MemRefType
1899 computeCollapsedType(MemRefType srcType,
1900  ArrayRef<ReassociationIndices> reassociation) {
1901  SmallVector<int64_t> resultShape;
1902  resultShape.reserve(reassociation.size());
1903  for (const ReassociationIndices &group : reassociation) {
1904  using saturated_arith::Wrapper;
1905  auto groupSize = Wrapper::size(1);
1906  for (int64_t srcDim : group)
1907  groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
1908  resultShape.push_back(groupSize.asSize());
1909  }
1910 
1911  if (srcType.getLayout().isIdentity()) {
1912  // If the source is contiguous (i.e., no layout map specified), so is the
1913  // result.
1914  MemRefLayoutAttrInterface layout;
1915  return MemRefType::get(resultShape, srcType.getElementType(), layout,
1916  srcType.getMemorySpace());
1917  }
1918 
1919  // Source may not be fully contiguous. Compute the layout map.
1920  // Note: Dimensions that are collapsed into a single dim are assumed to be
1921  // contiguous.
1922  FailureOr<AffineMap> computedLayout =
1923  computeCollapsedLayoutMap(srcType, reassociation);
1924  assert(succeeded(computedLayout) &&
1925  "invalid source layout map or collapsing non-contiguous dims");
1926  auto computedType =
1927  MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
1928  srcType.getMemorySpaceAsInt());
1929  return canonicalizeStridedLayout(computedType);
1930 }
1931 
1932 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1933  ArrayRef<ReassociationIndices> reassociation,
1934  ArrayRef<NamedAttribute> attrs) {
1935  auto srcType = src.getType().cast<MemRefType>();
1936  MemRefType resultType = computeCollapsedType(srcType, reassociation);
1937  build(b, result, resultType, src, attrs);
1938  result.addAttribute(getReassociationAttrStrName(),
1939  getReassociationIndicesAttribute(b, reassociation));
1940 }
1941 
1942 LogicalResult CollapseShapeOp::verify() {
1943  MemRefType srcType = getSrcType();
1944  MemRefType resultType = getResultType();
1945 
1946  // Verify result shape.
1947  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
1948  srcType.getShape(), getReassociationIndices(),
1949  /*allowMultipleDynamicDimsPerGroup=*/true)))
1950  return failure();
1951 
1952  // Compute expected result type (including layout map).
1953  MemRefType expectedResultType;
1954  if (srcType.getLayout().isIdentity()) {
1955  // If the source is contiguous (i.e., no layout map specified), so is the
1956  // result.
1957  MemRefLayoutAttrInterface layout;
1958  expectedResultType =
1959  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
1960  srcType.getMemorySpace());
1961  } else {
1962  // Source may not be fully contiguous. Compute the layout map.
1963  // Note: Dimensions that are collapsed into a single dim are assumed to be
1964  // contiguous.
1965  FailureOr<AffineMap> computedLayout =
1966  computeCollapsedLayoutMap(srcType, getReassociationIndices());
1967  if (failed(computedLayout))
1968  return emitOpError(
1969  "invalid source layout map or collapsing non-contiguous dims");
1970  auto computedType =
1971  MemRefType::get(resultType.getShape(), srcType.getElementType(),
1972  *computedLayout, srcType.getMemorySpaceAsInt());
1973  expectedResultType = canonicalizeStridedLayout(computedType);
1974  }
1975 
1976  auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
1977  if (expectedResultType != canonicalizedResultType)
1978  return emitOpError("expected collapsed type to be ")
1979  << expectedResultType << " but found " << canonicalizedResultType;
1980 
1981  return success();
1982 }
1983 
1984 class FoldCollapseOfCastOp
1985  : public OpRewritePattern<CollapseShapeOp> {
1986 public:
1987  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1988 
1989  LogicalResult matchAndRewrite(CollapseShapeOp op,
1990  PatternRewriter &rewriter) const override {
1991  auto cast = op.getOperand().getDefiningOp<CastOp>();
1992  if (!cast)
1993  return failure();
1994 
1995  if (!CastOp::canFoldIntoConsumerOp(cast))
1996  return failure();
1997 
1998  Type newResultType =
1999  computeCollapsedType(cast.getOperand().getType().cast<MemRefType>(),
2000  op.getReassociationIndices());
2001 
2002  if (newResultType == op.getResultType()) {
2003  rewriter.updateRootInPlace(
2004  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2005  } else {
2006  Value newOp = rewriter.create<CollapseShapeOp>(
2007  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2008  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2009  }
2010  return success();
2011  }
2012 };
2013 
2014 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2015  MLIRContext *context) {
2016  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
2017  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
2018  FoldCollapseOfCastOp>(context);
2019 }
2020 
2021 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
2022  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
2023 }
2024 
2025 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
2026  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
2027 }
2028 
2029 //===----------------------------------------------------------------------===//
2030 // ReshapeOp
2031 //===----------------------------------------------------------------------===//
2032 
2033 LogicalResult ReshapeOp::verify() {
2034  Type operandType = getSource().getType();
2035  Type resultType = getResult().getType();
2036 
2037  Type operandElementType = operandType.cast<ShapedType>().getElementType();
2038  Type resultElementType = resultType.cast<ShapedType>().getElementType();
2039  if (operandElementType != resultElementType)
2040  return emitOpError("element types of source and destination memref "
2041  "types should be the same");
2042 
2043  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
2044  if (!operandMemRefType.getLayout().isIdentity())
2045  return emitOpError("source memref type should have identity affine map");
2046 
2047  int64_t shapeSize = getShape().getType().cast<MemRefType>().getDimSize(0);
2048  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
2049  if (resultMemRefType) {
2050  if (!resultMemRefType.getLayout().isIdentity())
2051  return emitOpError("result memref type should have identity affine map");
2052  if (shapeSize == ShapedType::kDynamicSize)
2053  return emitOpError("cannot use shape operand with dynamic length to "
2054  "reshape to statically-ranked memref type");
2055  if (shapeSize != resultMemRefType.getRank())
2056  return emitOpError(
2057  "length of shape operand differs from the result's memref rank");
2058  }
2059  return success();
2060 }
2061 
2062 //===----------------------------------------------------------------------===//
2063 // StoreOp
2064 //===----------------------------------------------------------------------===//
2065 
2066 LogicalResult StoreOp::verify() {
2067  if (getNumOperands() != 2 + getMemRefType().getRank())
2068  return emitOpError("store index operand count not equal to memref rank");
2069 
2070  return success();
2071 }
2072 
2073 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
2074  SmallVectorImpl<OpFoldResult> &results) {
2075  /// store(memrefcast) -> store
2076  return foldMemRefCast(*this, getValueToStore());
2077 }
2078 
2079 //===----------------------------------------------------------------------===//
2080 // SubViewOp
2081 //===----------------------------------------------------------------------===//
2082 
2083 /// A subview result type can be fully inferred from the source type and the
2084 /// static representation of offsets, sizes and strides. Special sentinels
2085 /// encode the dynamic case.
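///
/// Worked example (hypothetical values): for a source memref<8x16xf32>
/// (strides [16, 1], offset 0) with offsets [2, 3], sizes [4, 4] and strides
/// [1, 2], the inferred offset is 0 + 2 * 16 + 3 * 1 = 35 and the inferred
/// strides are [16 * 1, 1 * 2] = [16, 2], i.e. a 4x4 view with a strided
/// layout of offset 35 and strides [16, 2].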
2086 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2087  ArrayRef<int64_t> staticOffsets,
2088  ArrayRef<int64_t> staticSizes,
2089  ArrayRef<int64_t> staticStrides) {
2090  unsigned rank = sourceMemRefType.getRank();
2091  (void)rank;
2092  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2093  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2094  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2095 
2096  // Extract source offset and strides.
2097  int64_t sourceOffset;
2098  SmallVector<int64_t, 4> sourceStrides;
2099  auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
2100  assert(succeeded(res) && "SubViewOp expected strided memref type");
2101  (void)res;
2102 
2103  // Compute target offset whose value is:
2104  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2105  int64_t targetOffset = sourceOffset;
2106  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2107  auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
2108  using saturated_arith::Wrapper;
2109  targetOffset =
2110  (Wrapper::offset(targetOffset) +
2111  Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
2112  .asOffset();
2113  }
2114 
2115  // Compute target stride whose value is:
2116  // `sourceStrides_i * staticStrides_i`.
2117  SmallVector<int64_t, 4> targetStrides;
2118  targetStrides.reserve(staticOffsets.size());
2119  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2120  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2121  using saturated_arith::Wrapper;
2122  targetStrides.push_back(
2123  (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
2124  .asStride());
2125  }
2126 
2127  // The type is now known.
2128  return MemRefType::get(
2129  staticSizes, sourceMemRefType.getElementType(),
2130  makeStridedLinearLayoutMap(targetStrides, targetOffset,
2131  sourceMemRefType.getContext()),
2132  sourceMemRefType.getMemorySpace());
2133 }
2134 
2135 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2136  ArrayRef<OpFoldResult> offsets,
2137  ArrayRef<OpFoldResult> sizes,
2138  ArrayRef<OpFoldResult> strides) {
2139  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2140  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2141  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2142  ShapedType::kDynamicStrideOrOffset);
2143  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2144  ShapedType::kDynamicSize);
2145  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2146  ShapedType::kDynamicStrideOrOffset);
2147  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2148  staticSizes, staticStrides);
2149 }
2150 
2151 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2152  MemRefType sourceRankedTensorType,
2153  ArrayRef<int64_t> offsets,
2154  ArrayRef<int64_t> sizes,
2155  ArrayRef<int64_t> strides) {
2156  auto inferredType =
2157  inferResultType(sourceRankedTensorType, offsets, sizes, strides)
2158  .cast<MemRefType>();
2159  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2160  "expected ");
2161  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2162  return inferredType;
2163 
2164  // Compute which dimensions are dropped.
2165  Optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2166  computeRankReductionMask(inferredType.getShape(), resultShape);
2167  assert(dimsToProject.has_value() && "invalid rank reduction");
2168  llvm::SmallBitVector dimsToProjectVector(inferredType.getRank());
2169  for (unsigned dim : *dimsToProject)
2170  dimsToProjectVector.set(dim);
2171 
2172  // Compute layout map and result type.
2173  AffineMap map = getProjectedMap(inferredType.getLayout().getAffineMap(),
2174  dimsToProjectVector);
2175  return MemRefType::get(resultShape, inferredType.getElementType(), map,
2176  inferredType.getMemorySpace());
2177 }
2178 
2179 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2180  MemRefType sourceRankedTensorType,
2181  ArrayRef<OpFoldResult> offsets,
2182  ArrayRef<OpFoldResult> sizes,
2183  ArrayRef<OpFoldResult> strides) {
2184  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2185  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2186  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2187  ShapedType::kDynamicStrideOrOffset);
2188  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2189  ShapedType::kDynamicSize);
2190  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2191  ShapedType::kDynamicStrideOrOffset);
2192  return SubViewOp::inferRankReducedResultType(
2193  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2194  staticStrides);
2195 }
2196 
2197 // Build a SubViewOp with mixed static and dynamic entries and custom result
2198 // type. If the type passed is nullptr, it is inferred.
2199 void SubViewOp::build(OpBuilder &b, OperationState &result,
2200  MemRefType resultType, Value source,
2201  ArrayRef<OpFoldResult> offsets,
2202  ArrayRef<OpFoldResult> sizes,
2203  ArrayRef<OpFoldResult> strides,
2204  ArrayRef<NamedAttribute> attrs) {
2205  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2206  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2207  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2208  ShapedType::kDynamicStrideOrOffset);
2209  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2210  ShapedType::kDynamicSize);
2211  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2212  ShapedType::kDynamicStrideOrOffset);
2213  auto sourceMemRefType = source.getType().cast<MemRefType>();
2214  // Structuring implementation this way avoids duplication between builders.
2215  if (!resultType) {
2216  resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2217  staticSizes, staticStrides)
2218  .cast<MemRefType>();
2219  }
2220  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2221  dynamicStrides, b.getI64ArrayAttr(staticOffsets),
2222  b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
2223  result.addAttributes(attrs);
2224 }
2225 
2226 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2227 // type.
2228 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2229  ArrayRef<OpFoldResult> offsets,
2230  ArrayRef<OpFoldResult> sizes,
2231  ArrayRef<OpFoldResult> strides,
2232  ArrayRef<NamedAttribute> attrs) {
2233  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2234 }
2235 
2236 // Build a SubViewOp with static entries and inferred result type.
2237 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2238  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2239  ArrayRef<int64_t> strides,
2240  ArrayRef<NamedAttribute> attrs) {
2241  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2242  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2243  return b.getI64IntegerAttr(v);
2244  }));
2245  SmallVector<OpFoldResult> sizeValues =
2246  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2247  return b.getI64IntegerAttr(v);
2248  }));
2249  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2250  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2251  return b.getI64IntegerAttr(v);
2252  }));
2253  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2254 }
2255 
2256 // Build a SubViewOp with static entries and custom result type. If the
2257 // type passed is nullptr, it is inferred.
2258 void SubViewOp::build(OpBuilder &b, OperationState &result,
2259  MemRefType resultType, Value source,
2260  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2261  ArrayRef<int64_t> strides,
2262  ArrayRef<NamedAttribute> attrs) {
2263  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2264  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2265  return b.getI64IntegerAttr(v);
2266  }));
2267  SmallVector<OpFoldResult> sizeValues =
2268  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2269  return b.getI64IntegerAttr(v);
2270  }));
2271  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2272  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2273  return b.getI64IntegerAttr(v);
2274  }));
2275  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2276  attrs);
2277 }
2278 
2279 // Build a SubViewOp with dynamic entries and custom result type. If the type
2280 // passed is nullptr, it is inferred.
2281 void SubViewOp::build(OpBuilder &b, OperationState &result,
2282  MemRefType resultType, Value source, ValueRange offsets,
2283  ValueRange sizes, ValueRange strides,
2284  ArrayRef<NamedAttribute> attrs) {
2285  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2286  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2287  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2288  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2289  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2290  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2291  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2292 }
2293 
2294 // Build a SubViewOp with dynamic entries and inferred result type.
2295 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2296  ValueRange offsets, ValueRange sizes, ValueRange strides,
2297  ArrayRef<NamedAttribute> attrs) {
2298  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2299 }
2300 
2301 /// For ViewLikeOpInterface.
2302 Value SubViewOp::getViewSource() { return getSource(); }
2303 
2304 /// Return true if t1 and t2 have equal offsets (both dynamic or of same
2305 /// static value).
2306 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2307  AffineExpr t1Offset, t2Offset;
2308  SmallVector<AffineExpr> t1Strides, t2Strides;
2309  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2310  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2311  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2312 }
2313 
2314 /// Checks if `originalType` can be rank-reduced to `candidateRankReducedType`.
2315 /// This function is a slight variant of the `is subsequence` algorithm where
2316 /// the non-matching dimensions must be 1.
2317 static SliceVerificationResult
2318 isRankReducedMemRefType(MemRefType originalType,
2319  MemRefType candidateRankReducedType,
2320  ArrayRef<OpFoldResult> sizes) {
2321  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
2322  if (partialRes != SliceVerificationResult::Success)
2323  return partialRes;
2324 
2325  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
2326  originalType, candidateRankReducedType, sizes);
2327 
2328  // Sizes cannot be matched in case empty vector is returned.
2329  if (!optionalUnusedDimsMask)
2330  return SliceVerificationResult::SizeMismatch;
2331 
2332  if (originalType.getMemorySpace() !=
2333  candidateRankReducedType.getMemorySpace())
2334  return SliceVerificationResult::MemSpaceMismatch;
2335 
2336  // No amount of stride dropping can reconcile incompatible offsets.
2337  if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
2338  return SliceVerificationResult::LayoutMismatch;
2339 
2340  return SliceVerificationResult::Success;
2341 }
2342 
2343 template <typename OpTy>
2344 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2345  OpTy op, Type expectedType) {
2346  auto memrefType = expectedType.cast<ShapedType>();
2347  switch (result) {
2348  case SliceVerificationResult::Success:
2349  return success();
2350  case SliceVerificationResult::RankTooLarge:
2351  return op.emitError("expected result rank to be smaller or equal to ")
2352  << "the source rank. ";
2353  case SliceVerificationResult::SizeMismatch:
2354  return op.emitError("expected result type to be ")
2355  << expectedType
2356  << " or a rank-reduced version. (mismatch of result sizes) ";
2357  case SliceVerificationResult::ElemTypeMismatch:
2358  return op.emitError("expected result element type to be ")
2359  << memrefType.getElementType();
2360  case SliceVerificationResult::MemSpaceMismatch:
2361  return op.emitError("expected result and source memory spaces to match.");
2362  case SliceVerificationResult::LayoutMismatch:
2363  return op.emitError("expected result type to be ")
2364  << expectedType
2365  << " or a rank-reduced version. (mismatch of result layout) ";
2366  }
2367  llvm_unreachable("unexpected subview verification result");
2368 }
2369 
2370 /// Verifier for SubViewOp.
2371 LogicalResult SubViewOp::verify() {
2372  MemRefType baseType = getSourceType();
2373  MemRefType subViewType = getType();
2374 
2375  // The base memref and the view memref should be in the same memory space.
2376  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2377  return emitError("different memory spaces specified for base memref "
2378  "type ")
2379  << baseType << " and subview memref type " << subViewType;
2380 
2381  // Verify that the base memref type has a strided layout map.
2382  if (!isStrided(baseType))
2383  return emitError("base type ") << baseType << " is not strided";
2384 
2385  // Verify result type against inferred type.
2386  auto expectedType = SubViewOp::inferResultType(
2387  baseType, extractFromI64ArrayAttr(getStaticOffsets()),
2388  extractFromI64ArrayAttr(getStaticSizes()),
2389  extractFromI64ArrayAttr(getStaticStrides()));
2390 
2391  auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
2392  subViewType, getMixedSizes());
2393  return produceSubViewErrorMsg(result, *this, expectedType);
2394 }
2395 
2396 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2397  return os << "range " << range.offset << ":" << range.size << ":"
2398  << range.stride;
2399 }
2400 
2401 /// Return the list of Range (i.e. offset, size, stride). Each Range
2402 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2403 /// with `b` at location `loc`.
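///
/// For example (illustrative operands): for a subview written as
/// %v[%i, 4] [8, %n] [1, 1], this returns the ranges {%i, 8, 1} and
/// {4, %n, 1}, where the static entries 4, 8 and 1 are materialized as
/// arith.constant ... : index operations.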
2404 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2405  OpBuilder &b, Location loc) {
2406  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
2407  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
2408  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
2409  SmallVector<Range, 8> res;
2410  unsigned rank = ranks[0];
2411  res.reserve(rank);
2412  for (unsigned idx = 0; idx < rank; ++idx) {
2413  Value offset =
2414  op.isDynamicOffset(idx)
2415  ? op.getDynamicOffset(idx)
2416  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
2417  Value size =
2418  op.isDynamicSize(idx)
2419  ? op.getDynamicSize(idx)
2420  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
2421  Value stride =
2422  op.isDynamicStride(idx)
2423  ? op.getDynamicStride(idx)
2424  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
2425  res.emplace_back(Range{offset, size, stride});
2426  }
2427  return res;
2428 }
2429 
2430 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
2431 /// to deduce the result type for the given `sourceType`. Additionally, reduce
2432 /// the rank of the inferred result type if `currentResultType` is lower rank
2433 /// than `currentSourceType`. Use this signature if `sourceType` is updated
2434 /// together with the result type. In this case, it is important to compute
2435 /// the dropped dimensions using `currentSourceType` whose strides align with
2436 /// `currentResultType`.
2437 static MemRefType getCanonicalSubViewResultType(
2438  MemRefType currentResultType, MemRefType currentSourceType,
2439  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
2440  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
2441  auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
2442  mixedSizes, mixedStrides)
2443  .cast<MemRefType>();
2444  llvm::Optional<llvm::SmallBitVector> unusedDims =
2445  computeMemRefRankReductionMask(currentSourceType, currentResultType,
2446  mixedSizes);
2447  // Return nullptr as failure mode.
2448  if (!unusedDims)
2449  return nullptr;
2450  SmallVector<int64_t> shape;
2451  for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
2452  if (unusedDims->test(sizes.index()))
2453  continue;
2454  shape.push_back(sizes.value());
2455  }
2456  AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
2457  if (!layoutMap.isIdentity())
2458  layoutMap = getProjectedMap(layoutMap, *unusedDims);
2459  return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
2460  nonRankReducedType.getMemorySpace());
2461 }
2462 
2463 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
2464 /// to deduce the result type. Additionally, reduce the rank of the inferred
2465 /// result type if `currentResultType` is lower rank than `sourceType`.
2466 static MemRefType getCanonicalSubViewResultType(
2467  MemRefType currentResultType, MemRefType sourceType,
2468  ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
2469  ArrayRef<OpFoldResult> mixedStrides) {
2470  return getCanonicalSubViewResultType(currentResultType, sourceType,
2471  sourceType, mixedOffsets, mixedSizes,
2472  mixedStrides);
2473 }
2474 
2475 /// Helper method to check if a `subview` operation is trivially a no-op. This
2476 /// is the case if all offsets are zero, all strides are 1, and the source
2477 /// shape is the same as the size of the subview. In such cases, the subview
2478 /// can be folded into its source.
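///
/// Illustrative IR:
///   %0 = memref.subview %arg[0, 0] [4, 4] [1, 1]
///       : memref<4x4xf32> to memref<4x4xf32>
/// is such a no-op and folds to %arg.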
2479 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
2480  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
2481  return false;
2482 
2483  auto mixedOffsets = subViewOp.getMixedOffsets();
2484  auto mixedSizes = subViewOp.getMixedSizes();
2485  auto mixedStrides = subViewOp.getMixedStrides();
2486 
2487  // Check offsets are zero.
2488  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
2489  Optional<int64_t> intValue = getConstantIntValue(ofr);
2490  return !intValue || intValue.value() != 0;
2491  }))
2492  return false;
2493 
2494  // Check strides are one.
2495  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
2496  Optional<int64_t> intValue = getConstantIntValue(ofr);
2497  return !intValue || intValue.value() != 1;
2498  }))
2499  return false;
2500 
2501  // Check all size values are static and match the (static) source shape.
2502  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
2503  for (const auto &size : llvm::enumerate(mixedSizes)) {
2504  Optional<int64_t> intValue = getConstantIntValue(size.value());
2505  if (!intValue || *intValue != sourceShape[size.index()])
2506  return false;
2507  }
2508  // All conditions met. The `SubViewOp` is foldable as a no-op.
2509  return true;
2510 }
2511 
2512 namespace {
2513 /// Pattern to rewrite a subview op with MemRefCast arguments.
2514 /// This essentially pushes memref.cast past its consuming subview when
2515 /// `canFoldIntoConsumerOp` is true.
2516 ///
2517 /// Example:
2518 /// ```
2519 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
2520 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
2521 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
2522 /// ```
2523 /// is rewritten into:
2524 /// ```
2525 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
2526 /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
2527 /// memref<3x4xf32, offset:?, strides:[?, 1]>
2528 /// ```
2529 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2530 public:
2531  using OpRewritePattern<SubViewOp>::OpRewritePattern;
2532 
2533  LogicalResult matchAndRewrite(SubViewOp subViewOp,
2534  PatternRewriter &rewriter) const override {
2535  // Any constant operand, just return to let SubViewOpConstantFolder kick
2536  // in.
2537  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2538  return matchPattern(operand, matchConstantIndex());
2539  }))
2540  return failure();
2541 
2542  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
2543  if (!castOp)
2544  return failure();
2545 
2546  if (!CastOp::canFoldIntoConsumerOp(castOp))
2547  return failure();
2548 
2549  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
2550  // the MemRefCastOp source operand type to infer the result type and the
2551  // current SubViewOp source operand type to compute the dropped dimensions
2552  // if the operation is rank-reducing.
2553  auto resultType = getCanonicalSubViewResultType(
2554  subViewOp.getType(), subViewOp.getSourceType(),
2555  castOp.getSource().getType().cast<MemRefType>(),
2556  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
2557  subViewOp.getMixedStrides());
2558  if (!resultType)
2559  return failure();
2560 
2561  Value newSubView = rewriter.create<SubViewOp>(
2562  subViewOp.getLoc(), resultType, castOp.getSource(),
2563  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
2564  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
2565  subViewOp.getStaticStrides());
2566  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2567  newSubView);
2568  return success();
2569  }
2570 };
2571 
2572 /// Canonicalize subview ops that are no-ops: fold them into their source, or
2573 /// insert a cast when the types differ only due to an `affine_map` layout.
2574 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
2575 public:
2576  using OpRewritePattern<SubViewOp>::OpRewritePattern;
2577 
2578  LogicalResult matchAndRewrite(SubViewOp subViewOp,
2579  PatternRewriter &rewriter) const override {
2580  if (!isTrivialSubViewOp(subViewOp))
2581  return failure();
2582  if (subViewOp.getSourceType() == subViewOp.getType()) {
2583  rewriter.replaceOp(subViewOp, subViewOp.getSource());
2584  return success();
2585  }
2586  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2587  subViewOp.getSource());
2588  return success();
2589  }
2590 };
2591 } // namespace
2592 
2593 /// Return the canonical type of the result of a subview.
2594 struct SubViewReturnTypeCanonicalizer {
2595  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
2596  ArrayRef<OpFoldResult> mixedSizes,
2597  ArrayRef<OpFoldResult> mixedStrides) {
2598  return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
2599  mixedOffsets, mixedSizes,
2600  mixedStrides);
2601  }
2602 };
2603 
2604 /// A canonicalizer wrapper to replace SubViewOps.
2605 struct SubViewCanonicalizer {
2606  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
2607  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2608  }
2609 };
2610 
2611 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2612  MLIRContext *context) {
2613  results
2614  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2615  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
2616  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
2617 }
2618 
2619 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
2620  auto resultShapedType = getResult().getType().cast<ShapedType>();
2621  auto sourceShapedType = getSource().getType().cast<ShapedType>();
2622 
2623  if (resultShapedType.hasStaticShape() &&
2624  resultShapedType == sourceShapedType) {
2625  return getViewSource();
2626  }
2627 
2628  return {};
2629 }
2630 
2631 //===----------------------------------------------------------------------===//
2632 // TransposeOp
2633 //===----------------------------------------------------------------------===//
2634 
2635 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
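///
/// Sketch (hypothetical input): with permutation (d0, d1) -> (d1, d0), a
/// memref<3x4xf32> (strides [4, 1]) maps to a 4x3 result whose strides are
/// [1, 4], i.e. roughly memref<4x3xf32, affine_map<(d0, d1) -> (d0 + d1 * 4)>>.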
2636 static MemRefType inferTransposeResultType(MemRefType memRefType,
2637  AffineMap permutationMap) {
2638  auto rank = memRefType.getRank();
2639  auto originalSizes = memRefType.getShape();
2640  // Compute permuted sizes.
2641  SmallVector<int64_t, 4> sizes(rank, 0);
2642  for (const auto &en : llvm::enumerate(permutationMap.getResults()))
2643  sizes[en.index()] =
2644  originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2645 
2646  // Compute permuted strides.
2647  int64_t offset;
2648  SmallVector<int64_t, 4> strides;
2649  auto res = getStridesAndOffset(memRefType, strides, offset);
2650  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2651  (void)res;
2652  auto map =
2653  makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2654  map = permutationMap ? map.compose(permutationMap) : map;
2655  return MemRefType::Builder(memRefType)
2656  .setShape(sizes)
2657  .setLayout(AffineMapAttr::get(map));
2658 }
2659 
2660 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2661  AffineMapAttr permutation,
2662  ArrayRef<NamedAttribute> attrs) {
2663  auto permutationMap = permutation.getValue();
2664  assert(permutationMap);
2665 
2666  auto memRefType = in.getType().cast<MemRefType>();
2667  // Compute result type.
2668  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2669 
2670  build(b, result, resultType, in, attrs);
2671  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
2672 }
2673 
2674 // transpose $in $permutation attr-dict : type($in) `to` type(results)
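// Example of the custom syntax (illustrative; the exact result layout depends
// on the input strides):
//   memref.transpose %v (i, j) -> (j, i)
//       : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>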
2676  p << " " << getIn() << " " << getPermutation();
2677  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
2678  p << " : " << getIn().getType() << " to " << getType();
2679 }
2680 
2681 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2682  OpAsmParser::UnresolvedOperand in;
2683  AffineMap permutation;
2684  MemRefType srcType, dstType;
2685  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2686  parser.parseOptionalAttrDict(result.attributes) ||
2687  parser.parseColonType(srcType) ||
2688  parser.resolveOperand(in, srcType, result.operands) ||
2689  parser.parseKeywordType("to", dstType) ||
2690  parser.addTypeToList(dstType, result.types))
2691  return failure();
2692 
2693  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
2694  AffineMapAttr::get(permutation));
2695  return success();
2696 }
2697 
2698 LogicalResult TransposeOp::verify() {
2699  if (!getPermutation().isPermutation())
2700  return emitOpError("expected a permutation map");
2701  if (getPermutation().getNumDims() != getShapedType().getRank())
2702  return emitOpError("expected a permutation map of same rank as the input");
2703 
2704  auto srcType = getIn().getType().cast<MemRefType>();
2705  auto dstType = getType().cast<MemRefType>();
2706  auto transposedType = inferTransposeResultType(srcType, getPermutation());
2707  if (dstType != transposedType)
2708  return emitOpError("output type ")
2709  << dstType << " does not match transposed input type " << srcType
2710  << ", " << transposedType;
2711  return success();
2712 }
2713 
2714 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2715  if (succeeded(foldMemRefCast(*this)))
2716  return getResult();
2717  return {};
2718 }
2719 
2720 //===----------------------------------------------------------------------===//
2721 // ViewOp
2722 //===----------------------------------------------------------------------===//
2723 
2724 LogicalResult ViewOp::verify() {
2725  auto baseType = getOperand(0).getType().cast<MemRefType>();
2726  auto viewType = getType();
2727 
2728  // The base memref should have identity layout map (or none).
2729  if (!baseType.getLayout().isIdentity())
2730  return emitError("unsupported map for base memref type ") << baseType;
2731 
2732  // The result memref should have identity layout map (or none).
2733  if (!viewType.getLayout().isIdentity())
2734  return emitError("unsupported map for result memref type ") << viewType;
2735 
2736  // The base memref and the view memref should be in the same memory space.
2737  if (baseType.getMemorySpace() != viewType.getMemorySpace())
2738  return emitError("different memory spaces specified for base memref "
2739  "type ")
2740  << baseType << " and view memref type " << viewType;
2741 
2742  // Verify that we have the correct number of sizes for the result type.
2743  unsigned numDynamicDims = viewType.getNumDynamicDims();
2744  if (getSizes().size() != numDynamicDims)
2745  return emitError("incorrect number of size operands for type ") << viewType;
2746 
2747  return success();
2748 }
2749 
2750 Value ViewOp::getViewSource() { return getSource(); }
2751 
2752 namespace {
2753 
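/// Folds dynamic size operands of a view that are produced by constants into
/// the result type. Illustrative IR (hypothetical values):
///   %c16 = arith.constant 16 : index
///   %v = memref.view %src[%offset][%c16] : memref<2048xi8> to memref<?xf32>
/// becomes a view with result type memref<16xf32>, followed by a cast back to
/// memref<?xf32>.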
2754 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2755  using OpRewritePattern<ViewOp>::OpRewritePattern;
2756 
2757  LogicalResult matchAndRewrite(ViewOp viewOp,
2758  PatternRewriter &rewriter) const override {
2759  // Return if none of the operands are constants.
2760  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2761  return matchPattern(operand, matchConstantIndex());
2762  }))
2763  return failure();
2764 
2765  // Get result memref type.
2766  auto memrefType = viewOp.getType();
2767 
2768  // Get offset from old memref view type 'memRefType'.
2769  int64_t oldOffset;
2770  SmallVector<int64_t, 4> oldStrides;
2771  if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2772  return failure();
2773  assert(oldOffset == 0 && "Expected 0 offset");
2774 
2775  SmallVector<Value, 4> newOperands;
2776 
2777  // Offset cannot be folded into result type.
2778 
2779  // Fold any dynamic dim operands which are produced by a constant.
2780  SmallVector<int64_t, 4> newShapeConstants;
2781  newShapeConstants.reserve(memrefType.getRank());
2782 
2783  unsigned dynamicDimPos = 0;
2784  unsigned rank = memrefType.getRank();
2785  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2786  int64_t dimSize = memrefType.getDimSize(dim);
2787  // If this is already a static dimension, keep it.
2788  if (!ShapedType::isDynamic(dimSize)) {
2789  newShapeConstants.push_back(dimSize);
2790  continue;
2791  }
2792  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
2793  if (auto constantIndexOp =
2794  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
2795  // Dynamic shape dimension will be folded.
2796  newShapeConstants.push_back(constantIndexOp.value());
2797  } else {
2798  // Dynamic shape dimension not folded; copy operand from old memref.
2799  newShapeConstants.push_back(dimSize);
2800  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
2801  }
2802  dynamicDimPos++;
2803  }
2804 
2805  // Create new memref type with constant folded dims.
2806  MemRefType newMemRefType =
2807  MemRefType::Builder(memrefType).setShape(newShapeConstants);
2808  // Nothing new, don't fold.
2809  if (newMemRefType == memrefType)
2810  return failure();
2811 
2812  // Create new ViewOp.
2813  auto newViewOp = rewriter.create<ViewOp>(
2814  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
2815  viewOp.getByteShift(), newOperands);
2816  // Insert a cast so we have the same type as the old memref type.
2817  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
2818  return success();
2819  }
2820 };
2821 
2822 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2823  using OpRewritePattern<ViewOp>::OpRewritePattern;
2824 
2825  LogicalResult matchAndRewrite(ViewOp viewOp,
2826  PatternRewriter &rewriter) const override {
2827  Value memrefOperand = viewOp.getOperand(0);
2828  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2829  if (!memrefCastOp)
2830  return failure();
2831  Value allocOperand = memrefCastOp.getOperand();
2832  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2833  if (!allocOp)
2834  return failure();
2835  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2836  viewOp.getByteShift(),
2837  viewOp.getSizes());
2838  return success();
2839  }
2840 };
2841 
2842 } // namespace
2843 
2844 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2845  MLIRContext *context) {
2846  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2847 }
2848 
2849 //===----------------------------------------------------------------------===//
2850 // AtomicRMWOp
2851 //===----------------------------------------------------------------------===//
2852 
2853 LogicalResult AtomicRMWOp::verify() {
2854  if (getMemRefType().getRank() != getNumOperands() - 2)
2855  return emitOpError(
2856  "expects the number of subscripts to be equal to memref rank");
2857  switch (getKind()) {
2858  case arith::AtomicRMWKind::addf:
2859  case arith::AtomicRMWKind::maxf:
2860  case arith::AtomicRMWKind::minf:
2861  case arith::AtomicRMWKind::mulf:
2862  if (!getValue().getType().isa<FloatType>())
2863  return emitOpError() << "with kind '"
2864  << arith::stringifyAtomicRMWKind(getKind())
2865  << "' expects a floating-point type";
2866  break;
2867  case arith::AtomicRMWKind::addi:
2868  case arith::AtomicRMWKind::maxs:
2869  case arith::AtomicRMWKind::maxu:
2870  case arith::AtomicRMWKind::mins:
2871  case arith::AtomicRMWKind::minu:
2872  case arith::AtomicRMWKind::muli:
2873  case arith::AtomicRMWKind::ori:
2874  case arith::AtomicRMWKind::andi:
2875  if (!getValue().getType().isa<IntegerType>())
2876  return emitOpError() << "with kind '"
2877  << arith::stringifyAtomicRMWKind(getKind())
2878  << "' expects an integer type";
2879  break;
2880  default:
2881  break;
2882  }
2883  return success();
2884 }
2885 
2886 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
2887  /// atomicrmw(memrefcast) -> atomicrmw
2888  if (succeeded(foldMemRefCast(*this, getValue())))
2889  return getResult();
2890  return OpFoldResult();
2891 }
2892 
2893 //===----------------------------------------------------------------------===//
2894 // TableGen'd op method definitions
2895 //===----------------------------------------------------------------------===//
2896 
2897 #define GET_OP_CLASSES
2898 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
bool isa() const
Definition: Attributes.h:111
void push_back(Block *block)
Definition: Region.h:61
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:685
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition: MemRefOps.cpp:104
bool operator!=(StringAttr lhs, std::nullptr_t)
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
A canonicalizer wrapper to replace SubViewOps.
Definition: MemRefOps.cpp:2605
This is the representation of an operand reference.
Operation & front()
Definition: Block.h:144
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:244
This trait indicates that the side effects of an operation includes the effects of operations nested ...
static FailureOr< MemRefType > computeExpandedType(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Definition: MemRefOps.cpp:1747
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:144
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual ParseResult parseComma()=0
Parse a , token.
Idiomatic saturated operations on offsets, sizes and strides.
Definition: MemRefOps.cpp:31
static constexpr const bool value
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Region.h:98
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of...
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
Auxiliary range data structure to unpack the offset, size and stride operands into a list of triples...
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand...
Definition: MemRefOps.cpp:1613
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:358
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:300
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:242
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
bool isStrided(MemRefType t)
Return true if the layout for t is compatible with strided semantics.
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:78
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it...
Definition: MemRefOps.cpp:781
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseGreater()=0
Parse a &#39;>&#39; token.
OpFoldResult size
void addOperands(ValueRange newOperands)
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
Definition: MemRefOps.cpp:117
U dyn_cast() const
Definition: Types.h:270
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:99
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource...
Definition: MemRefOps.cpp:322
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, unsigned memorySpace=0)
Return a MemRefType to which the type of the given value can be bufferized.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:355
static FailureOr< AffineMap > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
Definition: MemRefOps.cpp:1822
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:528
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:165
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:395
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:347
Operation::operand_range getIndices(Operation *op)
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
OpFoldResult stride
bool isIndex() const
Definition: Types.cpp:28
Base type for affine expression.
Definition: AffineExpr.h:68
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
void addTypes(ArrayRef< Type > newTypes)
virtual ParseResult parseLess()=0
Parse a &#39;<&#39; token.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
static WalkResult advance()
Definition: Visitors.h:51
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:499
This represents an operation in an abstracted form, suitable for use with the builder APIs...
void mergeBlockBefore(Block *source, Operation *op, ValueRange argValues=llvm::None)
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true...
A multi-dimensional affine map Affine map&#39;s are immutable like Type&#39;s, and they are uniqued...
Definition: AffineMap.h:42
static WalkResult interrupt()
Definition: Visitors.h:50
ParseResult resolveOperands(ArrayRef< UnresolvedOperand > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:307
OpFoldResult offset
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:93
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
AffineExpr operator+(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:244
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
Definition: MemRefOps.cpp:2404
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Definition: MemRefOps.cpp:2606
static bool hasSideEffects(Operation *op)
bool isPermutation(ArrayRef< int64_t > permutation)
Check if permutation is a permutation of the range [0, permutation.size()).
Definition: Utils.cpp:187
static WalkResult skip()
Definition: Visitors.h:52
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
Definition: MemRefOps.cpp:352
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
Definition: MemRefOps.cpp:2437
static int resultIndex(int i)
Definition: Operator.cpp:308
NamedAttrList attributes
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
static MemRefType computeCollapsedType(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation)
Definition: MemRefOps.cpp:1899
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents a successor of a region.
Region * addRegion()
Create a region that should be attached to the operation.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
Type getType() const
Return the type of this value.
Definition: Value.h:118
IndexType getIndexType()
Definition: Builders.cpp:48
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:112
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
Definition: MemRefOps.cpp:303
U dyn_cast() const
Definition: Attributes.h:127
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:332
AffineMap getProjectedMap(AffineMap map, const llvm::SmallBitVector &projectedDimensions)
Returns the map that results from projecting out the dimensions specified in projectedDimensions.
Definition: AffineMap.cpp:715
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:753
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:164
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
Definition: MemRefOps.cpp:1252
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
Definition: MemRefOps.cpp:2479
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:80
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:255
type_range getType() const
Definition: ValueRange.cpp:30
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
virtual ParseResult parseType(Type &result)=0
Parse a type.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
AffineMap makeStridedLinearLayoutMap(ArrayRef< int64_t > strides, int64_t offset, MLIRContext *context)
Given a list of strides (in which MemRefType::getDynamicStrideOrOffset() represents a dynamic value)...
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap tp memRefType.
Definition: MemRefOps.cpp:2636
This class represents an operand of an operation.
Definition: Value.h:251
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: MemRefOps.cpp:2595
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:161
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
Definition: MemRefOps.cpp:345
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into &#39;result&#39; if it is present.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:185
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:245
static SliceVerificationResult isRankReducedMemRefType(MemRefType originalType, MemRefType candidateRankReducedType, ArrayRef< OpFoldResult > sizes)
Checks if original Type type can be rank reduced to reduced type.
Definition: MemRefOps.cpp:2318
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "&#39;dim&#39; op " which is convenient for verifiers...
Definition: Operation.cpp:508
SmallVector< int64_t, 4 > extractFromI64ArrayAttr(Attribute attr)
Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
bool isa() const
Definition: Types.h:254
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:221
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:67
This class represents success/failure for parsing-like operations that find it important to chain tog...
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:89
Optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Move allocations into an allocation scope, if it is legal to move them (e.g.
Definition: MemRefOps.cpp:392
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:95
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition: MemRefOps.cpp:1266
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
The following effect indicates that the operation allocates from some resource.
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:175
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
Return the canonical type of the result of a subview.
Definition: MemRefOps.cpp:2594
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type. ...
Definition: FoldUtils.cpp:50
static llvm::Optional< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
Definition: MemRefOps.cpp:796
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
Square brackets surrounding zero or more operands.
U cast() const
Definition: Types.h:278
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
SmallVector< Type, 4 > types
Types of the results of this operation.
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition: Utils.cpp:22