MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Arith/Utils/Utils.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/Utils/StaticValueUtils.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
23 #include "mlir/Interfaces/ViewLikeInterface.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::memref;
29 
30 /// Materialize a single constant operation from a given attribute value with
31 /// the desired resultant type.
32 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
33  Attribute value, Type type,
34  Location loc) {
35  return arith::ConstantOp::materialize(builder, value, type, loc);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Common canonicalization pattern support logic
40 //===----------------------------------------------------------------------===//
41 
42 /// This is a common class used for patterns of the form
43 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
44 /// into the root operation directly.
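/// For example (an illustrative fold, mirroring what DeallocOp::fold below
/// does with this helper):
/// ```mlir
/// %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
/// memref.dealloc %0 : memref<?xf32>
/// ```
/// becomes
/// ```mlir
/// memref.dealloc %arg : memref<8xf32>
/// ```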
45 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
46  bool folded = false;
47  for (OpOperand &operand : op->getOpOperands()) {
48  auto cast = operand.get().getDefiningOp<CastOp>();
49  if (cast && operand.get() != inner &&
50  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
51  operand.set(cast.getOperand());
52  folded = true;
53  }
54  }
55  return success(folded);
56 }
57 
58 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
59 /// type.
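/// For example, `memref<4x?xf32>` maps to `tensor<4x?xf32>`, `memref<*xf32>`
/// maps to `tensor<*xf32>`, and any non-memref type maps to `none`.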
60 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
61  if (auto memref = llvm::dyn_cast<MemRefType>(type))
62  return RankedTensorType::get(memref.getShape(), memref.getElementType());
63  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
64  return UnrankedTensorType::get(memref.getElementType());
65  return NoneType::get(type.getContext());
66 }
67 
68 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
69  int64_t dim) {
70  auto memrefType = llvm::cast<MemRefType>(value.getType());
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78  Location loc, Value value) {
79  auto memrefType = llvm::cast<MemRefType>(value.getType());
80  SmallVector<OpFoldResult> result;
81  for (int64_t i = 0; i < memrefType.getRank(); ++i)
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result;
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that sets values[i] to constValues[i] if the latter is a
91 /// static value (i.e., one not equal to ShapedType::kDynamic).
92 ///
93 /// If constValues[i] is dynamic, tries to extract a constant value from
94 /// values[i] to allow for additional folding opportunities. Also converts all
95 /// existing attributes to index attributes. (They may be i64 attributes.)
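/// For example (illustrative values): with `constValues = [8, kDynamic]` and
/// `values = [%sz, %c4]`, where `%c4` is produced by an `arith.constant 4`,
/// the result is `values = [8 : index, 4 : index]`. A truly dynamic `%sz`
/// whose corresponding const value is kDynamic is left untouched.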
96 static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
97  ArrayRef<int64_t> constValues) {
98  assert(constValues.size() == values.size() &&
99  "incorrect number of const values");
100  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
101  Builder builder(values[i].getContext());
102  if (!ShapedType::isDynamic(cstVal)) {
103  // Constant value is known, use it directly.
104  values[i] = builder.getIndexAttr(cstVal);
105  continue;
106  }
107  if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
108  // Try to extract a constant or convert an existing attribute to index.
109  values[i] = builder.getIndexAttr(*cst);
110  }
111  }
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // AllocOp / AllocaOp
116 //===----------------------------------------------------------------------===//
117 
118 void AllocOp::getAsmResultNames(
119  function_ref<void(Value, StringRef)> setNameFn) {
120  setNameFn(getResult(), "alloc");
121 }
122 
123 void AllocaOp::getAsmResultNames(
124  function_ref<void(Value, StringRef)> setNameFn) {
125  setNameFn(getResult(), "alloca");
126 }
127 
128 template <typename AllocLikeOp>
129 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
130  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
131  "applies to only alloc or alloca");
132  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
133  if (!memRefType)
134  return op.emitOpError("result must be a memref");
135 
136  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
137  return op.emitOpError("dimension operand count does not equal memref "
138  "dynamic dimension count");
139 
140  unsigned numSymbols = 0;
141  if (!memRefType.getLayout().isIdentity())
142  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
143  if (op.getSymbolOperands().size() != numSymbols)
144  return op.emitOpError("symbol operand count does not equal memref symbol "
145  "count: expected ")
146  << numSymbols << ", got " << op.getSymbolOperands().size();
147 
148  return success();
149 }
150 
151 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
152 
153 LogicalResult AllocaOp::verify() {
154  // An alloca op needs to have an ancestor with an allocation scope trait.
155  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
156  return emitOpError(
157  "requires an ancestor op with AutomaticAllocationScope trait");
158 
159  return verifyAllocLikeOp(*this);
160 }
161 
162 namespace {
163 /// Fold constant dimensions into an alloc like operation.
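/// For example (illustrative IR):
/// ```mlir
/// %c8 = arith.constant 8 : index
/// %0 = memref.alloc(%c8, %d) : memref<?x?xf32>
/// ```
/// is rewritten to
/// ```mlir
/// %1 = memref.alloc(%d) : memref<8x?xf32>
/// %0 = memref.cast %1 : memref<8x?xf32> to memref<?x?xf32>
/// ```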
164 template <typename AllocLikeOp>
165 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
166  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
167 
168  LogicalResult matchAndRewrite(AllocLikeOp alloc,
169  PatternRewriter &rewriter) const override {
170  // Check to see if any of the dimension operands are constants. If so, we can
171  // substitute and drop them.
172  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
173  APInt constSizeArg;
174  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
175  return false;
176  return constSizeArg.isNonNegative();
177  }))
178  return failure();
179 
180  auto memrefType = alloc.getType();
181 
182  // Ok, we have one or more constant operands. Collect the non-constant ones
183  // and keep track of the resultant memref type to build.
184  SmallVector<int64_t, 4> newShapeConstants;
185  newShapeConstants.reserve(memrefType.getRank());
186  SmallVector<Value, 4> dynamicSizes;
187 
188  unsigned dynamicDimPos = 0;
189  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
190  int64_t dimSize = memrefType.getDimSize(dim);
191  // If this is already a static dimension, keep it.
192  if (!ShapedType::isDynamic(dimSize)) {
193  newShapeConstants.push_back(dimSize);
194  continue;
195  }
196  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
197  APInt constSizeArg;
198  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
199  constSizeArg.isNonNegative()) {
200  // Dynamic shape dimension will be folded.
201  newShapeConstants.push_back(constSizeArg.getZExtValue());
202  } else {
203  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
204  newShapeConstants.push_back(ShapedType::kDynamic);
205  dynamicSizes.push_back(dynamicSize);
206  }
207  dynamicDimPos++;
208  }
209 
210  // Create new memref type (which will have fewer dynamic dimensions).
211  MemRefType newMemRefType =
212  MemRefType::Builder(memrefType).setShape(newShapeConstants);
213  assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
214 
215  // Create and insert the alloc op for the new memref.
216  auto newAlloc = rewriter.create<AllocLikeOp>(
217  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
218  alloc.getAlignmentAttr());
219  // Insert a cast so we have the same type as the old alloc.
220  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
221  return success();
222  }
223 };
224 
225 /// Fold alloc operations with no users or only store and dealloc uses.
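/// For example (illustrative IR), all three operations below are erased,
/// since the allocation is only written to and deallocated:
/// ```mlir
/// %0 = memref.alloc() : memref<4xf32>
/// memref.store %v, %0[%i] : memref<4xf32>
/// memref.dealloc %0 : memref<4xf32>
/// ```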
226 template <typename T>
227 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
228  using OpRewritePattern<T>::OpRewritePattern;
229 
230  LogicalResult matchAndRewrite(T alloc,
231  PatternRewriter &rewriter) const override {
232  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
233  if (auto storeOp = dyn_cast<StoreOp>(op))
234  return storeOp.getValue() == alloc;
235  return !isa<DeallocOp>(op);
236  }))
237  return failure();
238 
239  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
240  rewriter.eraseOp(user);
241 
242  rewriter.eraseOp(alloc);
243  return success();
244  }
245 };
246 } // namespace
247 
248 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
249  MLIRContext *context) {
250  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
251 }
252 
253 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
254  MLIRContext *context) {
255  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
256  context);
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // ReallocOp
261 //===----------------------------------------------------------------------===//
262 
263 LogicalResult ReallocOp::verify() {
264  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
265  MemRefType resultType = getType();
266 
267  // The source memref should have identity layout (or none).
268  if (!sourceType.getLayout().isIdentity())
269  return emitError("unsupported layout for source memref type ")
270  << sourceType;
271 
272  // The result memref should have identity layout (or none).
273  if (!resultType.getLayout().isIdentity())
274  return emitError("unsupported layout for result memref type ")
275  << resultType;
276 
277  // The source memref and the result memref should be in the same memory space.
278  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
279  return emitError("different memory spaces specified for source memref "
280  "type ")
281  << sourceType << " and result memref type " << resultType;
282 
283  // The source memref and the result memref should have the same element type.
284  if (sourceType.getElementType() != resultType.getElementType())
285  return emitError("different element types specified for source memref "
286  "type ")
287  << sourceType << " and result memref type " << resultType;
288 
289  // Verify that we have the dynamic dimension operand when it is needed.
290  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
291  return emitError("missing dimension operand for result type ")
292  << resultType;
293  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
294  return emitError("unnecessary dimension operand for result type ")
295  << resultType;
296 
297  return success();
298 }
299 
300 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
301  MLIRContext *context) {
302  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // AllocaScopeOp
307 //===----------------------------------------------------------------------===//
308 
309 void AllocaScopeOp::print(OpAsmPrinter &p) {
310  bool printBlockTerminators = false;
311 
312  p << ' ';
313  if (!getResults().empty()) {
314  p << " -> (" << getResultTypes() << ")";
315  printBlockTerminators = true;
316  }
317  p << ' ';
318  p.printRegion(getBodyRegion(),
319  /*printEntryBlockArgs=*/false,
320  /*printBlockTerminators=*/printBlockTerminators);
321  p.printOptionalAttrDict((*this)->getAttrs());
322 }
323 
324 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
325  // Create a region for the body.
326  result.regions.reserve(1);
327  Region *bodyRegion = result.addRegion();
328 
329  // Parse optional results type list.
330  if (parser.parseOptionalArrowTypeList(result.types))
331  return failure();
332 
333  // Parse the body region.
334  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
335  return failure();
336  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
337  result.location);
338 
339  // Parse the optional attribute list.
340  if (parser.parseOptionalAttrDict(result.attributes))
341  return failure();
342 
343  return success();
344 }
345 
346 void AllocaScopeOp::getSuccessorRegions(
347  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
348  if (!point.isParent()) {
349  regions.push_back(RegionSuccessor(getResults()));
350  return;
351  }
352 
353  regions.push_back(RegionSuccessor(&getBodyRegion()));
354 }
355 
356 /// Given an operation, return whether this op is guaranteed to
357 /// allocate an AutomaticAllocationScopeResource
358 static bool isGuaranteedAutomaticAllocation(Operation *op) {
359  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
360  if (!interface)
361  return false;
362  for (auto res : op->getResults()) {
363  if (auto effect =
364  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
365  if (isa<SideEffects::AutomaticAllocationScopeResource>(
366  effect->getResource()))
367  return true;
368  }
369  }
370  return false;
371 }
372 
373 /// Given an operation, return whether this op itself could
374 /// allocate an AutomaticAllocationScopeResource. Note that
375 /// this will not check whether an operation contained within
376 /// the op can allocate.
377 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
378  // This op itself doesn't create a stack allocation,
379  // the inner allocation should be handled separately.
380  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
381  return false;
382  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
383  if (!interface)
384  return true;
385  for (auto res : op->getResults()) {
386  if (auto effect =
387  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
388  if (isa<SideEffects::AutomaticAllocationScopeResource>(
389  effect->getResource()))
390  return true;
391  }
392  }
393  return false;
394 }
395 
396 /// Return whether this op is the last non terminating op
397 /// in a region. That is to say, it is in a one-block region
398 /// and is only followed by a terminator. This prevents
399 /// extending the lifetime of allocations.
400 static bool lastNonTerminatorInRegion(Operation *op) {
401  return op->getNextNode() == op->getBlock()->getTerminator() &&
402  llvm::hasSingleElement(op->getParentRegion()->getBlocks());
403 }
404 
405 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
406 /// or it contains no allocation.
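/// For example (illustrative IR), a scope that performs no stack allocation
/// is inlined into its parent:
/// ```mlir
/// %r = memref.alloca_scope -> (f32) {
///   %v = arith.addf %a, %b : f32
///   memref.alloca_scope.return %v : f32
/// }
/// ```
/// becomes simply `%r = arith.addf %a, %b : f32`.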
407 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
408  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
409 
410  LogicalResult matchAndRewrite(AllocaScopeOp op,
411  PatternRewriter &rewriter) const override {
412  bool hasPotentialAlloca =
413  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
414  if (alloc == op)
415  return WalkResult::advance();
416  if (isOpItselfPotentialAutomaticAllocation(alloc))
417  return WalkResult::interrupt();
418  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
419  return WalkResult::skip();
420  return WalkResult::advance();
421  }).wasInterrupted();
422 
423  // If this contains no potential allocation, it is always legal to
424  // inline. Otherwise, consider two conditions:
425  if (hasPotentialAlloca) {
426  // If the parent isn't an allocation scope, or we are not the last
427  // non-terminator op in the parent, we will extend the lifetime.
428  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
429  return failure();
430  if (!lastNonTerminatorInRegion(op))
431  return failure();
432  }
433 
434  Block *block = &op.getRegion().front();
435  Operation *terminator = block->getTerminator();
436  ValueRange results = terminator->getOperands();
437  rewriter.inlineBlockBefore(block, op);
438  rewriter.replaceOp(op, results);
439  rewriter.eraseOp(terminator);
440  return success();
441  }
442 };
443 
444 /// Move allocations into an allocation scope, if it is legal to
445 /// move them (e.g. their operands are available at the location
446 /// the op would be moved to).
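/// For example (an illustrative sketch; the precise legality conditions are
/// checked below), given
/// ```mlir
/// scf.for %i = %lb to %ub step %s {
///   memref.alloca_scope {
///     %buf = memref.alloca() : memref<16xf32>
///     ...
///   }
/// }
/// ```
/// the `memref.alloca` can be cloned to just before the `scf.for`, directly
/// under the enclosing AutomaticAllocationScope op, so the buffer is
/// allocated once rather than on every iteration.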
447 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
448  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
449 
450  LogicalResult matchAndRewrite(AllocaScopeOp op,
451  PatternRewriter &rewriter) const override {
452 
453  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
454  return failure();
455 
456  Operation *lastParentWithoutScope = op->getParentOp();
457 
458  if (!lastParentWithoutScope ||
459  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
460  return failure();
461 
462  // Only apply if this is the last non-terminator
463  // op in the block (lest the lifetime be extended) of a
464  // one-block region.
465  if (!lastNonTerminatorInRegion(op) ||
466  !lastNonTerminatorInRegion(lastParentWithoutScope))
467  return failure();
468 
469  while (!lastParentWithoutScope->getParentOp()
470  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
471  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
472  if (!lastParentWithoutScope ||
473  !lastNonTerminatorInRegion(lastParentWithoutScope))
474  return failure();
475  }
476  assert(lastParentWithoutScope->getParentOp()
477  ->hasTrait<OpTrait::AutomaticAllocationScope>());
478 
479  Region *containingRegion = nullptr;
480  for (auto &r : lastParentWithoutScope->getRegions()) {
481  if (r.isAncestor(op->getParentRegion())) {
482  assert(containingRegion == nullptr &&
483  "only one region can contain the op");
484  containingRegion = &r;
485  }
486  }
487  assert(containingRegion && "op must be contained in a region");
488 
489  SmallVector<Operation *> toHoist;
490  op->walk([&](Operation *alloc) {
491  if (!isGuaranteedAutomaticAllocation(alloc))
492  return WalkResult::skip();
493 
494  // If any operand is not defined before the location of
495  // lastParentWithoutScope (i.e. where we would hoist to), skip.
496  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
497  return containingRegion->isAncestor(v.getParentRegion());
498  }))
499  return WalkResult::skip();
500  toHoist.push_back(alloc);
501  return WalkResult::advance();
502  });
503 
504  if (toHoist.empty())
505  return failure();
506  rewriter.setInsertionPoint(lastParentWithoutScope);
507  for (auto *op : toHoist) {
508  auto *cloned = rewriter.clone(*op);
509  rewriter.replaceOp(op, cloned->getResults());
510  }
511  return success();
512  }
513 };
514 
515 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
516  MLIRContext *context) {
517  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // AssumeAlignmentOp
522 //===----------------------------------------------------------------------===//
523 
524 LogicalResult AssumeAlignmentOp::verify() {
525  if (!llvm::isPowerOf2_32(getAlignment()))
526  return emitOpError("alignment must be power of 2");
527  return success();
528 }
529 
530 void AssumeAlignmentOp::getAsmResultNames(
531  function_ref<void(Value, StringRef)> setNameFn) {
532  setNameFn(getResult(), "assume_align");
533 }
534 
535 //===----------------------------------------------------------------------===//
536 // CastOp
537 //===----------------------------------------------------------------------===//
538 
539 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
540  setNameFn(getResult(), "cast");
541 }
542 
543 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
544 /// source memref. This is useful to fold a memref.cast into a consuming op
545 /// and implement canonicalization patterns for ops in different dialects that
546 /// may consume the results of memref.cast operations. Such foldable memref.cast
547 /// operations are typically inserted as `view` and `subview` ops are
548 /// canonicalized, to preserve the type compatibility of their uses.
549 ///
550 /// Returns true when all conditions are met:
551 /// 1. source and result are ranked memrefs with strided semantics and same
552 /// element type and rank.
553 /// 2. each of the source's size, offset or stride has more static information
554 /// than the corresponding result's size, offset or stride.
555 ///
556 /// Example 1:
557 /// ```mlir
558 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
559 /// %2 = consumer %1 ... : memref<?x?xf32> ...
560 /// ```
561 ///
562 /// may fold into:
563 ///
564 /// ```mlir
565 /// %2 = consumer %0 ... : memref<8x16xf32> ...
566 /// ```
567 ///
568 /// Example 2:
569 /// ```
570 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
571 /// to memref<?x?xf32>
572 /// consumer %1 : memref<?x?xf32> ...
573 /// ```
574 ///
575 /// may fold into:
576 ///
577 /// ```
578 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
579 /// ```
580 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
581  MemRefType sourceType =
582  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
583  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
584 
585  // Requires ranked MemRefType.
586  if (!sourceType || !resultType)
587  return false;
588 
589  // Requires same elemental type.
590  if (sourceType.getElementType() != resultType.getElementType())
591  return false;
592 
593  // Requires same rank.
594  if (sourceType.getRank() != resultType.getRank())
595  return false;
596 
597  // Only fold casts between strided memref forms.
598  int64_t sourceOffset, resultOffset;
599  SmallVector<int64_t, 4> sourceStrides, resultStrides;
600  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
601  failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
602  return false;
603 
604  // If cast is towards more static sizes along any dimension, don't fold.
605  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
606  auto ss = std::get<0>(it), st = std::get<1>(it);
607  if (ss != st)
608  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
609  return false;
610  }
611 
612  // If cast is towards more static offset along any dimension, don't fold.
613  if (sourceOffset != resultOffset)
614  if (ShapedType::isDynamic(sourceOffset) &&
615  !ShapedType::isDynamic(resultOffset))
616  return false;
617 
618  // If cast is towards more static strides along any dimension, don't fold.
619  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
620  auto ss = std::get<0>(it), st = std::get<1>(it);
621  if (ss != st)
622  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
623  return false;
624  }
625 
626  return true;
627 }
628 
629 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
630  if (inputs.size() != 1 || outputs.size() != 1)
631  return false;
632  Type a = inputs.front(), b = outputs.front();
633  auto aT = llvm::dyn_cast<MemRefType>(a);
634  auto bT = llvm::dyn_cast<MemRefType>(b);
635 
636  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
637  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
638 
639  if (aT && bT) {
640  if (aT.getElementType() != bT.getElementType())
641  return false;
642  if (aT.getLayout() != bT.getLayout()) {
643  int64_t aOffset, bOffset;
644  SmallVector<int64_t, 4> aStrides, bStrides;
645  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
646  failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
647  aStrides.size() != bStrides.size())
648  return false;
649 
650  // Strides along a dimension/offset are compatible if the value in the
651  // source memref is static and the value in the target memref is the
652  // same. They are also compatible if either one is dynamic (see
653  // description of MemRefCastOp for details).
654  auto checkCompatible = [](int64_t a, int64_t b) {
655  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
656  };
657  if (!checkCompatible(aOffset, bOffset))
658  return false;
659  for (const auto &aStride : enumerate(aStrides))
660  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
661  return false;
662  }
663  if (aT.getMemorySpace() != bT.getMemorySpace())
664  return false;
665 
666  // They must have the same rank, and any specified dimensions must match.
667  if (aT.getRank() != bT.getRank())
668  return false;
669 
670  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
671  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
672  if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
673  aDim != bDim)
674  return false;
675  }
676  return true;
677  } else {
678  if (!aT && !uaT)
679  return false;
680  if (!bT && !ubT)
681  return false;
682  // Unranked to unranked casting is unsupported
683  if (uaT && ubT)
684  return false;
685 
686  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
687  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
688  if (aEltType != bEltType)
689  return false;
690 
691  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
692  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
693  return aMemSpace == bMemSpace;
694  }
695 
696  return false;
697 }
698 
699 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
700  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
701 }
702 
703 //===----------------------------------------------------------------------===//
704 // CopyOp
705 //===----------------------------------------------------------------------===//
706 
707 namespace {
708 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
709 /// and element type, the cast can be skipped. Such CastOps only cast the layout
710 /// of the type.
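/// For example (illustrative IR), the copy below can read from %arg directly,
/// since the cast changes only the layout:
/// ```mlir
/// %0 = memref.cast %arg : memref<4xf32> to memref<4xf32, strided<[1]>>
/// memref.copy %0, %dst : memref<4xf32, strided<[1]>> to memref<4xf32>
/// ```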
711 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
712  using OpRewritePattern<CopyOp>::OpRewritePattern;
713 
714  LogicalResult matchAndRewrite(CopyOp copyOp,
715  PatternRewriter &rewriter) const override {
716  bool modified = false;
717 
718  // Check source.
719  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
720  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
721  auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
722 
723  if (fromType && toType) {
724  if (fromType.getShape() == toType.getShape() &&
725  fromType.getElementType() == toType.getElementType()) {
726  rewriter.modifyOpInPlace(copyOp, [&] {
727  copyOp.getSourceMutable().assign(castOp.getSource());
728  });
729  modified = true;
730  }
731  }
732  }
733 
734  // Check target.
735  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
736  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
737  auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
738 
739  if (fromType && toType) {
740  if (fromType.getShape() == toType.getShape() &&
741  fromType.getElementType() == toType.getElementType()) {
742  rewriter.modifyOpInPlace(copyOp, [&] {
743  copyOp.getTargetMutable().assign(castOp.getSource());
744  });
745  modified = true;
746  }
747  }
748  }
749 
750  return success(modified);
751  }
752 };
753 
754 /// Fold memref.copy(%x, %x).
755 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
756  using OpRewritePattern<CopyOp>::OpRewritePattern;
757 
758  LogicalResult matchAndRewrite(CopyOp copyOp,
759  PatternRewriter &rewriter) const override {
760  if (copyOp.getSource() != copyOp.getTarget())
761  return failure();
762 
763  rewriter.eraseOp(copyOp);
764  return success();
765  }
766 };
767 
768 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
769  using OpRewritePattern<CopyOp>::OpRewritePattern;
770 
771  static bool isEmptyMemRef(BaseMemRefType type) {
772  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
773  }
774 
775  LogicalResult matchAndRewrite(CopyOp copyOp,
776  PatternRewriter &rewriter) const override {
777  if (isEmptyMemRef(copyOp.getSource().getType()) ||
778  isEmptyMemRef(copyOp.getTarget().getType())) {
779  rewriter.eraseOp(copyOp);
780  return success();
781  }
782 
783  return failure();
784  }
785 };
786 } // namespace
787 
788 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
789  MLIRContext *context) {
790  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
791 }
792 
793 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
794  SmallVectorImpl<OpFoldResult> &results) {
795  /// copy(memrefcast) -> copy
796  bool folded = false;
797  Operation *op = *this;
798  for (OpOperand &operand : op->getOpOperands()) {
799  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
800  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
801  operand.set(castOp.getOperand());
802  folded = true;
803  }
804  }
805  return success(folded);
806 }
807 
808 //===----------------------------------------------------------------------===//
809 // DeallocOp
810 //===----------------------------------------------------------------------===//
811 
812 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
813  SmallVectorImpl<OpFoldResult> &results) {
814  /// dealloc(memrefcast) -> dealloc
815  return foldMemRefCast(*this);
816 }
817 
818 //===----------------------------------------------------------------------===//
819 // DimOp
820 //===----------------------------------------------------------------------===//
821 
822 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
823  setNameFn(getResult(), "dim");
824 }
825 
826 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
827  int64_t index) {
828  auto loc = result.location;
829  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
830  build(builder, result, source, indexValue);
831 }
832 
833 std::optional<int64_t> DimOp::getConstantIndex() {
834  return getConstantIntValue(getIndex());
835 }
836 
837 Speculation::Speculatability DimOp::getSpeculatability() {
838  auto constantIndex = getConstantIndex();
839  if (!constantIndex)
840  return Speculation::NotSpeculatable;
841 
842  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
843  if (!rankedSourceType)
844  return Speculation::NotSpeculatable;
845 
846  if (rankedSourceType.getRank() <= constantIndex)
847  return Speculation::NotSpeculatable;
848 
849  return Speculation::Speculatable;
850 }
851 
852 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
853  SetIntLatticeFn setResultRange) {
854  setResultRange(getResult(),
855  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
856 }
857 
858 /// Return a map from each element in `vals` to its number of occurrences.
859 /// Use std::map, since the `vals` here are strides and the
860 /// dynamic stride value is the same as the tombstone value for
861 /// `DenseMap<int64_t>`.
862 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
863  std::map<int64_t, unsigned> numOccurences;
864  for (auto val : vals)
865  numOccurences[val]++;
866  return numOccurences;
867 }
868 
869 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
870 /// to be a subset of `originalType` with some `1` entries erased, return the
871 /// set of indices that specifies which of the entries of `originalShape` are
872 /// dropped to obtain `reducedShape`.
873 /// This accounts for cases where there are multiple unit-dims, but only a
874 /// subset of those are dropped. For MemRefTypes these can be disambiguated
875 /// using the strides. If a dimension is dropped the stride must be dropped too.
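/// For example (illustrative types), when a subview reduces
/// `memref<1x4x1xf32>` (strides `[4, 1, 1]`) to `memref<4x1xf32>` (strides
/// `[1, 1]`), only the leading unit dimension is reported as dropped: its
/// stride `4` no longer occurs in the reduced type, whereas the trailing unit
/// dimension keeps its stride `1`.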
876 static FailureOr<llvm::SmallBitVector>
877 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
878  ArrayRef<OpFoldResult> sizes) {
879  llvm::SmallBitVector unusedDims(originalType.getRank());
880  if (originalType.getRank() == reducedType.getRank())
881  return unusedDims;
882 
883  for (const auto &dim : llvm::enumerate(sizes))
884  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
885  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
886  unusedDims.set(dim.index());
887 
888  // Early exit for the case where the number of unused dims matches the number
889  // of ranks reduced.
890  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
891  originalType.getRank())
892  return unusedDims;
893 
894  SmallVector<int64_t> originalStrides, candidateStrides;
895  int64_t originalOffset, candidateOffset;
896  if (failed(
897  originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
898  failed(
899  reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
900  return failure();
901 
902  // For memrefs, a dimension is truly dropped if its corresponding stride is
903  // also dropped. This is particularly important when more than one of the dims
904  // is 1. Track the number of occurrences of the strides in the original type
905  // and the candidate type. For each unused dim, that stride should not be
906  // present in the candidate type. Note that there could be multiple dimensions
907  // that have the same size. We don't need to figure out exactly which dim
908  // corresponds to which stride; we just need to verify that the number of
909  // repetitions of a stride in the original + number of unused dims with that
910  // stride == number of repetitions of a stride in the candidate.
911  std::map<int64_t, unsigned> currUnaccountedStrides =
912  getNumOccurences(originalStrides);
913  std::map<int64_t, unsigned> candidateStridesNumOccurences =
914  getNumOccurences(candidateStrides);
915  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
916  if (!unusedDims.test(dim))
917  continue;
918  int64_t originalStride = originalStrides[dim];
919  if (currUnaccountedStrides[originalStride] >
920  candidateStridesNumOccurences[originalStride]) {
921  // This dim can be treated as dropped.
922  currUnaccountedStrides[originalStride]--;
923  continue;
924  }
925  if (currUnaccountedStrides[originalStride] ==
926  candidateStridesNumOccurences[originalStride]) {
927  // The stride for this is not dropped. Keep as is.
928  unusedDims.reset(dim);
929  continue;
930  }
931  if (currUnaccountedStrides[originalStride] <
932  candidateStridesNumOccurences[originalStride]) {
933  // This should never happen. Can't have a stride in the reduced-rank type
934  // that wasn't in the original one.
935  return failure();
936  }
937  }
938 
939  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
940  originalType.getRank())
941  return failure();
942  return unusedDims;
943 }
944 
945 llvm::SmallBitVector SubViewOp::getDroppedDims() {
946  MemRefType sourceType = getSourceType();
947  MemRefType resultType = getType();
948  FailureOr<llvm::SmallBitVector> unusedDims =
949  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
950  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
951  return *unusedDims;
952 }
953 
954 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
955  // All forms of folding require a known index.
956  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
957  if (!index)
958  return {};
959 
960  // Folding for unranked types (UnrankedMemRefType) is not supported.
961  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
962  if (!memrefType)
963  return {};
964 
965  // Out of bound indices produce undefined behavior but are still valid IR.
966  // Don't choke on them.
967  int64_t indexVal = index.getInt();
968  if (indexVal < 0 || indexVal >= memrefType.getRank())
969  return {};
970 
971  // Fold if the shape extent along the given index is known.
972  if (!memrefType.isDynamicDim(index.getInt())) {
973  Builder builder(getContext());
974  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
975  }
976 
977  // The size at the given index is now known to be a dynamic size.
978  unsigned unsignedIndex = index.getValue().getZExtValue();
979 
980  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
981  Operation *definingOp = getSource().getDefiningOp();
982 
983  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
984  return *(alloc.getDynamicSizes().begin() +
985  memrefType.getDynamicDimIndex(unsignedIndex));
986 
987  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
988  return *(alloca.getDynamicSizes().begin() +
989  memrefType.getDynamicDimIndex(unsignedIndex));
990 
991  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
992  return *(view.getDynamicSizes().begin() +
993  memrefType.getDynamicDimIndex(unsignedIndex));
994 
995  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
996  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
997  unsigned resultIndex = 0;
998  unsigned sourceRank = subview.getSourceType().getRank();
999  unsigned sourceIndex = 0;
1000  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1001  if (unusedDims.test(i))
1002  continue;
1003  if (resultIndex == unsignedIndex) {
1004  sourceIndex = i;
1005  break;
1006  }
1007  resultIndex++;
1008  }
1009  assert(subview.isDynamicSize(sourceIndex) &&
1010  "expected dynamic subview size");
1011  return subview.getDynamicSize(sourceIndex);
1012  }
1013 
1014  if (auto sizeInterface =
1015  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1016  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1017  "Expected dynamic subview size");
1018  return sizeInterface.getDynamicSize(unsignedIndex);
1019  }
1020 
1021  // dim(memrefcast) -> dim
1022  if (succeeded(foldMemRefCast(*this)))
1023  return getResult();
1024 
1025  return {};
1026 }
1027 
1028 namespace {
1029 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1030 /// operand.
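/// For example (illustrative IR):
/// ```mlir
/// %r = memref.reshape %src(%shape)
///     : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
/// %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
/// is rewritten so that %d is loaded from the shape operand:
/// ```mlir
/// %d = memref.load %shape[%c1] : memref<2xindex>
/// ```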
1031 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1032  using OpRewritePattern<DimOp>::OpRewritePattern;
1033 
1034  LogicalResult matchAndRewrite(DimOp dim,
1035  PatternRewriter &rewriter) const override {
1036  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1037 
1038  if (!reshape)
1039  return rewriter.notifyMatchFailure(
1040  dim, "Dim op is not defined by a reshape op.");
1041 
1042  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1043  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1044  // cheaply check that either of the following conditions hold:
1045  // 1. dim.getIndex() is defined in the same block as reshape but before
1046  // reshape.
1047  // 2. dim.getIndex() is defined in a parent block of
1048  // reshape.
1049 
1050  // Check condition 1
1051  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1052  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1053  if (reshape->isBeforeInBlock(definingOp)) {
1054  return rewriter.notifyMatchFailure(
1055  dim,
1056  "dim.getIndex is not defined before reshape in the same block.");
1057  }
1058  } // else dim.getIndex is a block argument to reshape->getBlock and
1059  // dominates reshape
1060  } // Check condition 2
1061  else if (dim->getBlock() != reshape->getBlock() &&
1062  !dim.getIndex().getParentRegion()->isProperAncestor(
1063  reshape->getParentRegion())) {
1064  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1065  // already know dim.getIndex() dominates reshape without calling
1066  // `isProperAncestor`
1067  return rewriter.notifyMatchFailure(
1068  dim, "dim.getIndex does not dominate reshape.");
1069  }
1070 
1071  // Place the load directly after the reshape to ensure that the shape memref
1072  // was not mutated.
1073  rewriter.setInsertionPointAfter(reshape);
1074  Location loc = dim.getLoc();
1075  Value load =
1076  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1077  if (load.getType() != dim.getType())
1078  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1079  rewriter.replaceOp(dim, load);
1080  return success();
1081  }
1082 };
1083 
1084 } // namespace
1085 
1086 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1087  MLIRContext *context) {
1088  results.add<DimOfMemRefReshape>(context);
1089 }
1090 
1091 // ---------------------------------------------------------------------------
1092 // DmaStartOp
1093 // ---------------------------------------------------------------------------
1094 
1095 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1096  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1097  ValueRange destIndices, Value numElements,
1098  Value tagMemRef, ValueRange tagIndices, Value stride,
1099  Value elementsPerStride) {
1100  result.addOperands(srcMemRef);
1101  result.addOperands(srcIndices);
1102  result.addOperands(destMemRef);
1103  result.addOperands(destIndices);
1104  result.addOperands({numElements, tagMemRef});
1105  result.addOperands(tagIndices);
1106  if (stride)
1107  result.addOperands({stride, elementsPerStride});
1108 }
1109 
1110 void DmaStartOp::print(OpAsmPrinter &p) {
1111  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1112  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1113  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1114  if (isStrided())
1115  p << ", " << getStride() << ", " << getNumElementsPerStride();
1116 
1117  p.printOptionalAttrDict((*this)->getAttrs());
1118  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1119  << ", " << getTagMemRef().getType();
1120 }
1121 
1122 // Parse DmaStartOp.
1123 // Ex:
1124 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1125 // %tag[%index], %stride, %num_elt_per_stride :
1126 // : memref<3076 x f32, 0>,
1127 // memref<1024 x f32, 2>,
1128 // memref<1 x i32>
1129 //
1130 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1131  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1132  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1133  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1134  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1135  OpAsmParser::UnresolvedOperand numElementsInfo;
1136  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1137  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1138  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1139 
1140  SmallVector<Type, 3> types;
1141  auto indexType = parser.getBuilder().getIndexType();
1142 
1143  // Parse and resolve the following list of operands:
1144  // *) source memref followed by its indices (in square brackets).
1145  // *) destination memref followed by its indices (in square brackets).
1146  // *) dma size in KiB.
1147  if (parser.parseOperand(srcMemRefInfo) ||
1148  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1149  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1150  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1151  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1152  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1153  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1154  return failure();
1155 
1156  // Parse optional stride and elements per stride.
1157  if (parser.parseTrailingOperandList(strideInfo))
1158  return failure();
1159 
1160  bool isStrided = strideInfo.size() == 2;
1161  if (!strideInfo.empty() && !isStrided) {
1162  return parser.emitError(parser.getNameLoc(),
1163  "expected two stride related operands");
1164  }
1165 
1166  if (parser.parseColonTypeList(types))
1167  return failure();
1168  if (types.size() != 3)
1169  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1170 
1171  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1172  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1173  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1174  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1175  // size should be an index.
1176  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1177  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1178  // tag indices should be index.
1179  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1180  return failure();
1181 
1182  if (isStrided) {
1183  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1184  return failure();
1185  }
1186 
1187  return success();
1188 }
1189 
1190 LogicalResult DmaStartOp::verify() {
1191  unsigned numOperands = getNumOperands();
1192 
1193  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1194  // the number of elements.
1195  if (numOperands < 4)
1196  return emitOpError("expected at least 4 operands");
1197 
1198  // Check types of operands. The order of these calls is important: the later
1199  // calls rely on some type properties to compute the operand position.
1200  // 1. Source memref.
1201  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1202  return emitOpError("expected source to be of memref type");
1203  if (numOperands < getSrcMemRefRank() + 4)
1204  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1205  << " operands";
1206  if (!getSrcIndices().empty() &&
1207  !llvm::all_of(getSrcIndices().getTypes(),
1208  [](Type t) { return t.isIndex(); }))
1209  return emitOpError("expected source indices to be of index type");
1210 
1211  // 2. Destination memref.
1212  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1213  return emitOpError("expected destination to be of memref type");
1214  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1215  if (numOperands < numExpectedOperands)
1216  return emitOpError() << "expected at least " << numExpectedOperands
1217  << " operands";
1218  if (!getDstIndices().empty() &&
1219  !llvm::all_of(getDstIndices().getTypes(),
1220  [](Type t) { return t.isIndex(); }))
1221  return emitOpError("expected destination indices to be of index type");
1222 
1223  // 3. Number of elements.
1224  if (!getNumElements().getType().isIndex())
1225  return emitOpError("expected num elements to be of index type");
1226 
1227  // 4. Tag memref.
1228  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1229  return emitOpError("expected tag to be of memref type");
1230  numExpectedOperands += getTagMemRefRank();
1231  if (numOperands < numExpectedOperands)
1232  return emitOpError() << "expected at least " << numExpectedOperands
1233  << " operands";
1234  if (!getTagIndices().empty() &&
1235  !llvm::all_of(getTagIndices().getTypes(),
1236  [](Type t) { return t.isIndex(); }))
1237  return emitOpError("expected tag indices to be of index type");
1238 
1239  // Optional stride-related operands must be either both present or both
1240  // absent.
1241  if (numOperands != numExpectedOperands &&
1242  numOperands != numExpectedOperands + 2)
1243  return emitOpError("incorrect number of operands");
1244 
1245  // 5. Strides.
1246  if (isStrided()) {
1247  if (!getStride().getType().isIndex() ||
1248  !getNumElementsPerStride().getType().isIndex())
1249  return emitOpError(
1250  "expected stride and num elements per stride to be of type index");
1251  }
1252 
1253  return success();
1254 }
1255 
1256 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1257  SmallVectorImpl<OpFoldResult> &results) {
1258  /// dma_start(memrefcast) -> dma_start
1259  return foldMemRefCast(*this);
1260 }
1261 
1262 // ---------------------------------------------------------------------------
1263 // DmaWaitOp
1264 // ---------------------------------------------------------------------------
1265 
1266 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1267  SmallVectorImpl<OpFoldResult> &results) {
1268  /// dma_wait(memrefcast) -> dma_wait
1269  return foldMemRefCast(*this);
1270 }
1271 
1272 LogicalResult DmaWaitOp::verify() {
1273  // Check that the number of tag indices matches the tagMemRef rank.
1274  unsigned numTagIndices = getTagIndices().size();
1275  unsigned tagMemRefRank = getTagMemRefRank();
1276  if (numTagIndices != tagMemRefRank)
1277  return emitOpError() << "expected tagIndices to have the same number of "
1278  "elements as the tagMemRef rank, expected "
1279  << tagMemRefRank << ", but got " << numTagIndices;
1280  return success();
1281 }
1282 
1283 //===----------------------------------------------------------------------===//
1284 // ExtractAlignedPointerAsIndexOp
1285 //===----------------------------------------------------------------------===//
1286 
1287 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1288  function_ref<void(Value, StringRef)> setNameFn) {
1289  setNameFn(getResult(), "intptr");
1290 }
1291 
1292 //===----------------------------------------------------------------------===//
1293 // ExtractStridedMetadataOp
1294 //===----------------------------------------------------------------------===//
1295 
1296 /// The number and type of the results are inferred from the
1297 /// shape of the source.
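/// For example (illustrative), a source of type
/// `memref<4x?xf32, strided<[?, 1], offset: ?>>` yields a 0-d base buffer of
/// type `memref<f32>`, one `index` offset, two `index` sizes, and two `index`
/// strides.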
1298 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1299  MLIRContext *context, std::optional<Location> location,
1300  ExtractStridedMetadataOp::Adaptor adaptor,
1301  SmallVectorImpl<Type> &inferredReturnTypes) {
1302  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1303  if (!sourceType)
1304  return failure();
1305 
1306  unsigned sourceRank = sourceType.getRank();
1307  IndexType indexType = IndexType::get(context);
1308  auto memrefType =
1309  MemRefType::get({}, sourceType.getElementType(),
1310  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1311  // Base.
1312  inferredReturnTypes.push_back(memrefType);
1313  // Offset.
1314  inferredReturnTypes.push_back(indexType);
1315  // Sizes and strides.
1316  for (unsigned i = 0; i < sourceRank * 2; ++i)
1317  inferredReturnTypes.push_back(indexType);
1318  return success();
1319 }
1320 
1321 void ExtractStridedMetadataOp::getAsmResultNames(
1322  function_ref<void(Value, StringRef)> setNameFn) {
1323  setNameFn(getBaseBuffer(), "base_buffer");
1324  setNameFn(getOffset(), "offset");
1325  // For multi-result to work properly with pretty names and packed syntax `x:3`
1326  // we can only give a pretty name to the first value in the pack.
1327  if (!getSizes().empty()) {
1328  setNameFn(getSizes().front(), "sizes");
1329  setNameFn(getStrides().front(), "strides");
1330  }
1331 }
1332 
1333 /// Helper function to perform the replacement of all constant uses of `values`
1334 /// by a materialized constant extracted from `maybeConstants`.
1335 /// `values` and `maybeConstants` are expected to have the same size.
1336 template <typename Container>
1337 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1338  Container values,
1339  ArrayRef<OpFoldResult> maybeConstants) {
1340  assert(values.size() == maybeConstants.size() &&
1341  " expected values and maybeConstants of the same size");
1342  bool atLeastOneReplacement = false;
1343  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1344  // Don't materialize a constant if there are no uses: this would induce
1345  // infinite loops in the driver.
1346  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1347  continue;
1348  assert(isa<Attribute>(maybeConstant) &&
1349  "The constified value should be either unchanged (i.e., == result) "
1350  "or a constant");
1351  Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1352  loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1353  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1354  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1355  // yet.
1356  op->replaceUsesOfWith(result, constantVal);
1357  atLeastOneReplacement = true;
1358  }
1359  }
1360  return atLeastOneReplacement;
1361 }
1362 
1363 LogicalResult
1364 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1365  SmallVectorImpl<OpFoldResult> &results) {
1366  OpBuilder builder(*this);
1367 
1368  bool atLeastOneReplacement = replaceConstantUsesOf(
1369  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1370  getConstifiedMixedOffset());
1371  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1372  getConstifiedMixedSizes());
1373  atLeastOneReplacement |= replaceConstantUsesOf(
1374  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1375 
1376  return success(atLeastOneReplacement);
1377 }
1378 
1379 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1380  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1381  constifyIndexValues(values, getSource().getType().getShape());
1382  return values;
1383 }
1384 
1386 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1387  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1388  SmallVector<int64_t> staticValues;
1389  int64_t unused;
1390  LogicalResult status =
1391  getSource().getType().getStridesAndOffset(staticValues, unused);
1392  (void)status;
1393  assert(succeeded(status) && "could not get strides from type");
1394  constifyIndexValues(values, staticValues);
1395  return values;
1396 }
1397 
1398 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1399  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1400  SmallVector<OpFoldResult> values(1, offsetOfr);
1401  SmallVector<int64_t> staticValues, unused;
1402  int64_t offset;
1403  LogicalResult status =
1404  getSource().getType().getStridesAndOffset(unused, offset);
1405  (void)status;
1406  assert(succeeded(status) && "could not get offset from type");
1407  staticValues.push_back(offset);
1408  constifyIndexValues(values, staticValues);
1409  return values[0];
1410 }
1411 
1412 //===----------------------------------------------------------------------===//
1413 // GenericAtomicRMWOp
1414 //===----------------------------------------------------------------------===//
1415 
1416 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1417  Value memref, ValueRange ivs) {
1418  OpBuilder::InsertionGuard g(builder);
1419  result.addOperands(memref);
1420  result.addOperands(ivs);
1421 
1422  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1423  Type elementType = memrefType.getElementType();
1424  result.addTypes(elementType);
1425 
1426  Region *bodyRegion = result.addRegion();
1427  builder.createBlock(bodyRegion);
1428  bodyRegion->addArgument(elementType, memref.getLoc());
1429  }
1430 }
1431 
1432 LogicalResult GenericAtomicRMWOp::verify() {
1433  auto &body = getRegion();
1434  if (body.getNumArguments() != 1)
1435  return emitOpError("expected single number of entry block arguments");
1436 
1437  if (getResult().getType() != body.getArgument(0).getType())
1438  return emitOpError("expected block argument of the same type result type");
1439 
1440  bool hasSideEffects =
1441  body.walk([&](Operation *nestedOp) {
1442  if (isMemoryEffectFree(nestedOp))
1443  return WalkResult::advance();
1444  nestedOp->emitError(
1445  "body of 'memref.generic_atomic_rmw' should contain "
1446  "only operations with no side effects");
1447  return WalkResult::interrupt();
1448  })
1449  .wasInterrupted();
1450  return hasSideEffects ? failure() : success();
1451 }
1452 
1453 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1454  OperationState &result) {
1455  OpAsmParser::UnresolvedOperand memref;
1456  Type memrefType;
1457  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1458 
1459  Type indexType = parser.getBuilder().getIndexType();
1460  if (parser.parseOperand(memref) ||
1461  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1462  parser.parseColonType(memrefType) ||
1463  parser.resolveOperand(memref, memrefType, result.operands) ||
1464  parser.resolveOperands(ivs, indexType, result.operands))
1465  return failure();
1466 
1467  Region *body = result.addRegion();
1468  if (parser.parseRegion(*body, {}) ||
1469  parser.parseOptionalAttrDict(result.attributes))
1470  return failure();
1471  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1472  return success();
1473 }
1474 
1475 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1476  p << ' ' << getMemref() << "[" << getIndices()
1477  << "] : " << getMemref().getType() << ' ';
1478  p.printRegion(getRegion());
1479  p.printOptionalAttrDict((*this)->getAttrs());
1480 }
1481 
1482 //===----------------------------------------------------------------------===//
1483 // AtomicYieldOp
1484 //===----------------------------------------------------------------------===//
1485 
1486 LogicalResult AtomicYieldOp::verify() {
1487  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1488  Type resultType = getResult().getType();
1489  if (parentType != resultType)
1490  return emitOpError() << "types mismatch between yield op: " << resultType
1491  << " and its parent: " << parentType;
1492  return success();
1493 }
1494 
1495 //===----------------------------------------------------------------------===//
1496 // GlobalOp
1497 //===----------------------------------------------------------------------===//
1498 
1499 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1500  TypeAttr type,
1501  Attribute initialValue) {
1502  p << type;
1503  if (!op.isExternal()) {
1504  p << " = ";
1505  if (op.isUninitialized())
1506  p << "uninitialized";
1507  else
1508  p.printAttributeWithoutType(initialValue);
1509  }
1510 }
1511 
1512 static ParseResult
1513 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1514  Attribute &initialValue) {
1515  Type type;
1516  if (parser.parseType(type))
1517  return failure();
1518 
1519  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1520  if (!memrefType || !memrefType.hasStaticShape())
1521  return parser.emitError(parser.getNameLoc())
1522  << "type should be static shaped memref, but got " << type;
1523  typeAttr = TypeAttr::get(type);
1524 
1525  if (parser.parseOptionalEqual())
1526  return success();
1527 
1528  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1529  initialValue = UnitAttr::get(parser.getContext());
1530  return success();
1531  }
1532 
1533  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1534  if (parser.parseAttribute(initialValue, tensorType))
1535  return failure();
1536  if (!llvm::isa<ElementsAttr>(initialValue))
1537  return parser.emitError(parser.getNameLoc())
1538  << "initial value should be a unit or elements attribute";
1539  return success();
1540 }
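
// For reference, a few illustrative forms of the `type = initial-value`
// clause printed and parsed by the helpers above (sketch; the symbol names
// are assumptions, not taken from this file):
//
// ```
// memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>
// memref.global @z : memref<3xf32> = uninitialized
// memref.global "private" @ext : memref<4xi32>  // external, no initializer
// ```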
1541 
1542 LogicalResult GlobalOp::verify() {
1543  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1544  if (!memrefType || !memrefType.hasStaticShape())
1545  return emitOpError("type should be static shaped memref, but got ")
1546  << getType();
1547 
1548  // Verify that the initial value, if present, is either a unit attribute or
1549  // an elements attribute.
1550  if (getInitialValue().has_value()) {
1551  Attribute initValue = getInitialValue().value();
1552  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1553  return emitOpError("initial value should be a unit or elements "
1554  "attribute, but got ")
1555  << initValue;
1556 
1557  // Check that the type of the initial value is compatible with the type of
1558  // the global variable.
1559  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1560  Type initType = elementsAttr.getType();
1561  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1562  if (initType != tensorType)
1563  return emitOpError("initial value expected to be of type ")
1564  << tensorType << ", but was of type " << initType;
1565  }
1566  }
1567 
1568  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1569  uint64_t alignment = *alignAttr;
1570 
1571  if (!llvm::isPowerOf2_64(alignment))
1572  return emitError() << "alignment attribute value " << alignment
1573  << " is not a power of 2";
1574  }
1575 
1576  // TODO: verify visibility for declarations.
1577  return success();
1578 }
1579 
1580 ElementsAttr GlobalOp::getConstantInitValue() {
1581  auto initVal = getInitialValue();
1582  if (getConstant() && initVal.has_value())
1583  return llvm::cast<ElementsAttr>(initVal.value());
1584  return {};
1585 }
1586 
1587 //===----------------------------------------------------------------------===//
1588 // GetGlobalOp
1589 //===----------------------------------------------------------------------===//
1590 
1591 LogicalResult
1592 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1593  // Verify that the result type is the same as the type of the referenced
1594  // memref.global op.
1595  auto global =
1596  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1597  if (!global)
1598  return emitOpError("'")
1599  << getName() << "' does not reference a valid global memref";
1600 
1601  Type resultType = getResult().getType();
1602  if (global.getType() != resultType)
1603  return emitOpError("result type ")
1604  << resultType << " does not match type " << global.getType()
1605  << " of the global memref @" << getName();
1606  return success();
1607 }
1608 
1609 //===----------------------------------------------------------------------===//
1610 // LoadOp
1611 //===----------------------------------------------------------------------===//
1612 
1613 LogicalResult LoadOp::verify() {
1614  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1615  return emitOpError("incorrect number of indices for load, expected ")
1616  << getMemRefType().getRank() << " but got " << getIndices().size();
1617  }
1618  return success();
1619 }
1620 
1621 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1622  /// load(memrefcast) -> load
1623  if (succeeded(foldMemRefCast(*this)))
1624  return getResult();
1625  return OpFoldResult();
1626 }
1627 
1628 //===----------------------------------------------------------------------===//
1629 // MemorySpaceCastOp
1630 //===----------------------------------------------------------------------===//
1631 
1632 void MemorySpaceCastOp::getAsmResultNames(
1633  function_ref<void(Value, StringRef)> setNameFn) {
1634  setNameFn(getResult(), "memspacecast");
1635 }
1636 
1637 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1638  if (inputs.size() != 1 || outputs.size() != 1)
1639  return false;
1640  Type a = inputs.front(), b = outputs.front();
1641  auto aT = llvm::dyn_cast<MemRefType>(a);
1642  auto bT = llvm::dyn_cast<MemRefType>(b);
1643 
1644  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1645  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1646 
1647  if (aT && bT) {
1648  if (aT.getElementType() != bT.getElementType())
1649  return false;
1650  if (aT.getLayout() != bT.getLayout())
1651  return false;
1652  if (aT.getShape() != bT.getShape())
1653  return false;
1654  return true;
1655  }
1656  if (uaT && ubT) {
1657  return uaT.getElementType() == ubT.getElementType();
1658  }
1659  return false;
1660 }
1661 
1662 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1663  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1664  // t2)
1665  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1666  getSourceMutable().assign(parentCast.getSource());
1667  return getResult();
1668  }
1669  return Value{};
1670 }
1671 
1672 //===----------------------------------------------------------------------===//
1673 // PrefetchOp
1674 //===----------------------------------------------------------------------===//
1675 
1676 void PrefetchOp::print(OpAsmPrinter &p) {
1677  p << " " << getMemref() << '[';
1678  p.printOperands(getIndices());
1679  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1680  p << ", locality<" << getLocalityHint();
1681  p << ">, " << (getIsDataCache() ? "data" : "instr");
1682  p.printOptionalAttrDict(
1683  (*this)->getAttrs(),
1684  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1685  p << " : " << getMemRefType();
1686 }
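
// For reference, an illustrative textual form matching the printer/parser
// above (sketch; %A, %i and %j are assumed values, not names from this file):
//
// ```
// memref.prefetch %A[%i, %j], read, locality<3>, data : memref<400x400xf32>
// ```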
1687 
1688 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1689  OpAsmParser::UnresolvedOperand memrefInfo;
1690  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1691  IntegerAttr localityHint;
1692  MemRefType type;
1693  StringRef readOrWrite, cacheType;
1694 
1695  auto indexTy = parser.getBuilder().getIndexType();
1696  auto i32Type = parser.getBuilder().getIntegerType(32);
1697  if (parser.parseOperand(memrefInfo) ||
1698  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1699  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1700  parser.parseComma() || parser.parseKeyword("locality") ||
1701  parser.parseLess() ||
1702  parser.parseAttribute(localityHint, i32Type, "localityHint",
1703  result.attributes) ||
1704  parser.parseGreater() || parser.parseComma() ||
1705  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1706  parser.resolveOperand(memrefInfo, type, result.operands) ||
1707  parser.resolveOperands(indexInfo, indexTy, result.operands))
1708  return failure();
1709 
1710  if (readOrWrite != "read" && readOrWrite != "write")
1711  return parser.emitError(parser.getNameLoc(),
1712  "rw specifier has to be 'read' or 'write'");
1713  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1714  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1715 
1716  if (cacheType != "data" && cacheType != "instr")
1717  return parser.emitError(parser.getNameLoc(),
1718  "cache type has to be 'data' or 'instr'");
1719 
1720  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1721  parser.getBuilder().getBoolAttr(cacheType == "data"));
1722 
1723  return success();
1724 }
1725 
1726 LogicalResult PrefetchOp::verify() {
1727  if (getNumOperands() != 1 + getMemRefType().getRank())
1728  return emitOpError("too few indices");
1729 
1730  return success();
1731 }
1732 
1733 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1734  SmallVectorImpl<OpFoldResult> &results) {
1735  // prefetch(memrefcast) -> prefetch
1736  return foldMemRefCast(*this);
1737 }
1738 
1739 //===----------------------------------------------------------------------===//
1740 // RankOp
1741 //===----------------------------------------------------------------------===//
1742 
1743 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1744  // Constant fold rank when the rank of the operand is known.
1745  auto type = getOperand().getType();
1746  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1747  if (shapedType && shapedType.hasRank())
1748  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1749  return IntegerAttr();
1750 }
1751 
1752 //===----------------------------------------------------------------------===//
1753 // ReinterpretCastOp
1754 //===----------------------------------------------------------------------===//
1755 
1756 void ReinterpretCastOp::getAsmResultNames(
1757  function_ref<void(Value, StringRef)> setNameFn) {
1758  setNameFn(getResult(), "reinterpret_cast");
1759 }
1760 
1761 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1762 /// `staticSizes` and `staticStrides` are automatically filled with
1763 /// source-memref-rank sentinel values that encode dynamic entries.
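/// A sketch of the kind of IR such a build produces (illustrative only;
/// operand names and types are assumptions):
/// ```
/// %dst = memref.reinterpret_cast %src to
///   offset: [%off], sizes: [10, %n], strides: [%s, 1]
///   : memref<?x?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
/// ```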
1764 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1765  MemRefType resultType, Value source,
1766  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1767  ArrayRef<OpFoldResult> strides,
1768  ArrayRef<NamedAttribute> attrs) {
1769  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1770  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1771  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1772  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1773  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1774  result.addAttributes(attrs);
1775  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1776  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1777  b.getDenseI64ArrayAttr(staticSizes),
1778  b.getDenseI64ArrayAttr(staticStrides));
1779 }
1780 
1781 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1782  Value source, OpFoldResult offset,
1783  ArrayRef<OpFoldResult> sizes,
1784  ArrayRef<OpFoldResult> strides,
1785  ArrayRef<NamedAttribute> attrs) {
1786  auto sourceType = cast<BaseMemRefType>(source.getType());
1787  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1788  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1789  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1790  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1791  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1792  auto stridedLayout = StridedLayoutAttr::get(
1793  b.getContext(), staticOffsets.front(), staticStrides);
1794  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1795  stridedLayout, sourceType.getMemorySpace());
1796  build(b, result, resultType, source, offset, sizes, strides, attrs);
1797 }
1798 
1799 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1800  MemRefType resultType, Value source,
1801  int64_t offset, ArrayRef<int64_t> sizes,
1802  ArrayRef<int64_t> strides,
1803  ArrayRef<NamedAttribute> attrs) {
1804  SmallVector<OpFoldResult> sizeValues =
1805  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1806  return b.getI64IntegerAttr(v);
1807  }));
1808  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1809  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1810  return b.getI64IntegerAttr(v);
1811  }));
1812  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1813  strideValues, attrs);
1814 }
1815 
1816 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1817  MemRefType resultType, Value source, Value offset,
1818  ValueRange sizes, ValueRange strides,
1819  ArrayRef<NamedAttribute> attrs) {
1820  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1821  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1822  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1823  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1824  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1825 }
1826 
1827 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1828 // completed automatically, like we have for subview and extract_slice.
1829 LogicalResult ReinterpretCastOp::verify() {
1830  // The source and result memrefs should be in the same memory space.
1831  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1832  auto resultType = llvm::cast<MemRefType>(getType());
1833  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1834  return emitError("different memory spaces specified for source type ")
1835  << srcType << " and result memref type " << resultType;
1836  if (srcType.getElementType() != resultType.getElementType())
1837  return emitError("different element types specified for source type ")
1838  << srcType << " and result memref type " << resultType;
1839 
1840  // Match sizes in result memref type and in static_sizes attribute.
1841  for (auto [idx, resultSize, expectedSize] :
1842  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1843  if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
1844  return emitError("expected result type with size = ")
1845  << (ShapedType::isDynamic(expectedSize)
1846  ? std::string("dynamic")
1847  : std::to_string(expectedSize))
1848  << " instead of " << resultSize << " in dim = " << idx;
1849  }
1850 
1851  // Match offset and strides in static_offset and static_strides attributes. If
1852  // result memref type has no affine map specified, this will assume an
1853  // identity layout.
1854  int64_t resultOffset;
1855  SmallVector<int64_t, 4> resultStrides;
1856  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1857  return emitError("expected result type to have strided layout but found ")
1858  << resultType;
1859 
1860  // Match offset in result memref type and in static_offsets attribute.
1861  int64_t expectedOffset = getStaticOffsets().front();
1862  if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
1863  return emitError("expected result type with offset = ")
1864  << (ShapedType::isDynamic(expectedOffset)
1865  ? std::string("dynamic")
1866  : std::to_string(expectedOffset))
1867  << " instead of " << resultOffset;
1868 
1869  // Match strides in result memref type and in static_strides attribute.
1870  for (auto [idx, resultStride, expectedStride] :
1871  llvm::enumerate(resultStrides, getStaticStrides())) {
1872  if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
1873  return emitError("expected result type with stride = ")
1874  << (ShapedType::isDynamic(expectedStride)
1875  ? std::string("dynamic")
1876  : std::to_string(expectedStride))
1877  << " instead of " << resultStride << " in dim = " << idx;
1878  }
1879 
1880  return success();
1881 }
1882 
1883 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1884  Value src = getSource();
1885  auto getPrevSrc = [&]() -> Value {
1886  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1887  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1888  return prev.getSource();
1889 
1890  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1891  if (auto prev = src.getDefiningOp<CastOp>())
1892  return prev.getSource();
1893 
1894  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1895  // are 0.
1896  if (auto prev = src.getDefiningOp<SubViewOp>())
1897  if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1898  return isConstantIntValue(val, 0);
1899  }))
1900  return prev.getSource();
1901 
1902  return nullptr;
1903  };
1904 
1905  if (auto prevSrc = getPrevSrc()) {
1906  getSourceMutable().assign(prevSrc);
1907  return getResult();
1908  }
1909 
1910  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1911  if (!ShapedType::isDynamicShape(getType().getShape()) &&
1912  src.getType() == getType() && getStaticOffsets().front() == 0) {
1913  return src;
1914  }
1915 
1916  return nullptr;
1917 }
1918 
1919 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1920  SmallVector<OpFoldResult> values = getMixedSizes();
1921  constifyIndexValues(values, getType().getShape());
1922  return values;
1923 }
1924 
1925 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1926  SmallVector<OpFoldResult> values = getMixedStrides();
1927  SmallVector<int64_t> staticValues;
1928  int64_t unused;
1929  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
1930  (void)status;
1931  assert(succeeded(status) && "could not get strides from type");
1932  constifyIndexValues(values, staticValues);
1933  return values;
1934 }
1935 
1936 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1937  SmallVector<OpFoldResult> values = getMixedOffsets();
1938  assert(values.size() == 1 &&
1939  "reinterpret_cast must have one and only one offset");
1940  SmallVector<int64_t> staticValues, unused;
1941  int64_t offset;
1942  LogicalResult status = getType().getStridesAndOffset(unused, offset);
1943  (void)status;
1944  assert(succeeded(status) && "could not get offset from type");
1945  staticValues.push_back(offset);
1946  constifyIndexValues(values, staticValues);
1947  return values[0];
1948 }
1949 
1950 namespace {
1951 /// Replace the sequence:
1952 /// ```
1953 /// base, offset, sizes, strides = extract_strided_metadata src
1954 /// dst = reinterpret_cast base to offset, sizes, strides
1955 /// ```
1956 /// With
1957 ///
1958 /// ```
1959 /// dst = memref.cast src
1960 /// ```
1961 ///
1962 /// Note: The cast operation is only inserted when the type of dst and src
1963 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
1964 ///
1965 /// This pattern also matches when the offset, sizes, and strides don't come
1966 /// directly from the `extract_strided_metadata`'s results but when it can be
1967 /// statically proven that they would hold the same values.
1968 ///
1969 /// For instance, the following sequence would be replaced:
1970 /// ```
1971 /// base, offset, sizes, strides =
1972 /// extract_strided_metadata memref : memref<3x4xty>
1973 /// dst = reinterpret_cast base to 0, [3, 4], strides
1974 /// ```
1975 /// Because we know (thanks to the type of the input memref) that variables
1976 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
1977 ///
1978 /// Similarly, the following sequence would be replaced:
1979 /// ```
1980 /// c0 = arith.constant 0
1981 /// c4 = arith.constant 4
1982 /// base, offset, sizes, strides =
1983 /// extract_strided_metadata memref : memref<3x4xty>
1984 /// dst = reinterpret_cast base to c0, [3, c4], strides
1985 /// ```
1986 /// Because we know that `offset` and `c0` will hold 0
1987 /// and `c4` will hold 4.
1988 ///
1989 /// If the pattern above does not match, the input of the
1990 /// extract_strided_metadata is always folded into the input of the
1991 /// reinterpret_cast operator. This allows for dead code elimination to get rid
1992 /// of the extract_strided_metadata in some cases.
1993 struct ReinterpretCastOpExtractStridedMetadataFolder
1994  : public OpRewritePattern<ReinterpretCastOp> {
1995 public:
1996  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
1997 
1998  LogicalResult matchAndRewrite(ReinterpretCastOp op,
1999  PatternRewriter &rewriter) const override {
2000  auto extractStridedMetadata =
2001  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2002  if (!extractStridedMetadata)
2003  return failure();
2004 
2005  // Check if the reinterpret cast reconstructs a memref with the exact same
2006  // properties as the extract strided metadata.
2007  auto isReinterpretCastNoop = [&]() -> bool {
2008  // First, check that the strides are the same.
2009  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
2010  op.getConstifiedMixedStrides()))
2011  return false;
2012 
2013  // Second, check the sizes.
2014  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
2015  op.getConstifiedMixedSizes()))
2016  return false;
2017 
2018  // Finally, check the offset.
2019  assert(op.getMixedOffsets().size() == 1 &&
2020  "reinterpret_cast with more than one offset should have been "
2021  "rejected by the verifier");
2022  return extractStridedMetadata.getConstifiedMixedOffset() ==
2023  op.getConstifiedMixedOffset();
2024  };
2025 
2026  if (!isReinterpretCastNoop()) {
2027  // If the extract_strided_metadata / reinterpret_cast pair can't be
2028  // completely folded, then we could fold the input of the
2029  // extract_strided_metadata into the input of the reinterpret_cast.
2030  // For some cases (e.g., static dimensions) the extract_strided_metadata
2031  // is eliminated by dead code elimination.
2032  //
2033  // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2034  //
2035  // We can always fold the input of an extract_strided_metadata operator
2036  // to the input of a reinterpret_cast operator, because they point to
2037  // the same memory. Note that the reinterpret_cast does not use the
2038  // layout of its input memref, only its base memory pointer which is
2039  // the same as the base pointer returned by the extract_strided_metadata
2040  // operator and the base pointer of the extract_strided_metadata memref
2041  // input.
2042  rewriter.modifyOpInPlace(op, [&]() {
2043  op.getSourceMutable().assign(extractStridedMetadata.getSource());
2044  });
2045  return success();
2046  }
2047 
2048  // At this point, we know that the back and forth between extract strided
2049  // metadata and reinterpret cast is a noop. However, the final type of the
2050  // reinterpret cast may not be exactly the same as the original memref.
2051  // E.g., it could be changing a dimension from static to dynamic. Check that
2052  // here and add a cast if necessary.
2053  Type srcTy = extractStridedMetadata.getSource().getType();
2054  if (srcTy == op.getResult().getType())
2055  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2056  else
2057  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2058  extractStridedMetadata.getSource());
2059 
2060  return success();
2061  }
2062 };
2063 } // namespace
2064 
2065 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2066  MLIRContext *context) {
2067  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2068 }
2069 
2070 //===----------------------------------------------------------------------===//
2071 // Reassociative reshape ops
2072 //===----------------------------------------------------------------------===//
2073 
2074 void CollapseShapeOp::getAsmResultNames(
2075  function_ref<void(Value, StringRef)> setNameFn) {
2076  setNameFn(getResult(), "collapse_shape");
2077 }
2078 
2079 void ExpandShapeOp::getAsmResultNames(
2080  function_ref<void(Value, StringRef)> setNameFn) {
2081  setNameFn(getResult(), "expand_shape");
2082 }
2083 
2084 LogicalResult ExpandShapeOp::reifyResultShapes(
2085  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2086  reifiedResultShapes = {
2087  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2088  return success();
2089 }
2090 
2091 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2092 /// result and operand. Layout maps are verified separately.
2093 ///
2094 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2095 /// allowed in a reassociation group.
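/// For example (illustrative), collapsing the expanded shape 3x4x? into the
/// collapsed shape 12x? with reassociation [[0, 1], [2]] passes these checks:
/// each group holds at most one dynamic dimension and the all-static group
/// [0, 1] multiplies out to 3 * 4 = 12. In IR terms:
/// ```
/// memref.collapse_shape %m [[0, 1], [2]]
///     : memref<3x4x?xf32> into memref<12x?xf32>
/// ```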
2096 static LogicalResult
2097 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2098  ArrayRef<int64_t> expandedShape,
2099  ArrayRef<ReassociationIndices> reassociation,
2100  bool allowMultipleDynamicDimsPerGroup) {
2101  // There must be one reassociation group per collapsed dimension.
2102  if (collapsedShape.size() != reassociation.size())
2103  return op->emitOpError("invalid number of reassociation groups: found ")
2104  << reassociation.size() << ", expected " << collapsedShape.size();
2105 
2106  // The next expected expanded dimension index (while iterating over
2107  // reassociation indices).
2108  int64_t nextDim = 0;
2109  for (const auto &it : llvm::enumerate(reassociation)) {
2110  ReassociationIndices group = it.value();
2111  int64_t collapsedDim = it.index();
2112 
2113  bool foundDynamic = false;
2114  for (int64_t expandedDim : group) {
2115  if (expandedDim != nextDim++)
2116  return op->emitOpError("reassociation indices must be contiguous");
2117 
2118  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2119  return op->emitOpError("reassociation index ")
2120  << expandedDim << " is out of bounds";
2121 
2122  // Check if there are multiple dynamic dims in a reassociation group.
2123  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2124  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2125  return op->emitOpError(
2126  "at most one dimension in a reassociation group may be dynamic");
2127  foundDynamic = true;
2128  }
2129  }
2130 
2131  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2132  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2133  return op->emitOpError("collapsed dim (")
2134  << collapsedDim
2135  << ") must be dynamic if and only if reassociation group is "
2136  "dynamic";
2137 
2138  // If all dims in the reassociation group are static, the size of the
2139  // collapsed dim can be verified.
2140  if (!foundDynamic) {
2141  int64_t groupSize = 1;
2142  for (int64_t expandedDim : group)
2143  groupSize *= expandedShape[expandedDim];
2144  if (groupSize != collapsedShape[collapsedDim])
2145  return op->emitOpError("collapsed dim size (")
2146  << collapsedShape[collapsedDim]
2147  << ") must equal reassociation group size (" << groupSize << ")";
2148  }
2149  }
2150 
2151  if (collapsedShape.empty()) {
2152  // Rank 0: All expanded dimensions must be 1.
2153  for (int64_t d : expandedShape)
2154  if (d != 1)
2155  return op->emitOpError(
2156  "rank 0 memrefs can only be extended/collapsed with/from ones");
2157  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2158  // Rank >= 1: Number of dimensions among all reassociation groups must match
2159  // the result memref rank.
2160  return op->emitOpError("expanded rank (")
2161  << expandedShape.size()
2162  << ") inconsistent with number of reassociation indices (" << nextDim
2163  << ")";
2164  }
2165 
2166  return success();
2167 }
2168 
2169 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2170  return getSymbolLessAffineMaps(getReassociationExprs());
2171 }
2172 
2173 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2174  return convertReassociationIndicesToExprs(getContext(),
2175  getReassociationIndices());
2176 }
2177 
2178 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2179  return getSymbolLessAffineMaps(getReassociationExprs());
2180 }
2181 
2182 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2183  return convertReassociationIndicesToExprs(getContext(),
2184  getReassociationIndices());
2185 }
2186 
2187 /// Compute the layout map after expanding a given source MemRef type with the
2188 /// specified reassociation indices.
2189 static FailureOr<StridedLayoutAttr>
2190 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2191  ArrayRef<ReassociationIndices> reassociation) {
2192  int64_t srcOffset;
2193  SmallVector<int64_t> srcStrides;
2194  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2195  return failure();
2196  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2197 
2198  // 1-1 mapping between srcStrides and reassociation packs.
2199  // Each srcStride starts with the given value and gets expanded according to
2200  // the proper entries in resultShape.
2201  // Example:
2202  // srcStrides = [10000, 1 , 100 ],
2203  // reassociations = [ [0], [1], [2, 3, 4]],
2204  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2205  // -> For the purpose of stride calculation, the useful sizes are:
2206  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2207  // resultStrides = [10000, 1, 600, 200, 100]
2208  // Note that a stride does not get expanded along the first entry of each
2209  // shape pack.
2210  SmallVector<int64_t> reverseResultStrides;
2211  reverseResultStrides.reserve(resultShape.size());
2212  unsigned shapeIndex = resultShape.size() - 1;
2213  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2214  ReassociationIndices reassoc = std::get<0>(it);
2215  int64_t currentStrideToExpand = std::get<1>(it);
2216  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2217  reverseResultStrides.push_back(currentStrideToExpand);
2218  currentStrideToExpand =
2219  (SaturatedInteger::wrap(currentStrideToExpand) *
2220  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2221  .asInteger();
2222  }
2223  }
2224  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2225  resultStrides.resize(resultShape.size(), 1);
2226  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2227 }
2228 
2229 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2230  MemRefType srcType, ArrayRef<int64_t> resultShape,
2231  ArrayRef<ReassociationIndices> reassociation) {
2232  if (srcType.getLayout().isIdentity()) {
2233  // If the source is contiguous (i.e., no layout map specified), so is the
2234  // result.
2235  MemRefLayoutAttrInterface layout;
2236  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2237  srcType.getMemorySpace());
2238  }
2239 
2240  // Source may not be contiguous. Compute the layout map.
2241  FailureOr<StridedLayoutAttr> computedLayout =
2242  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2243  if (failed(computedLayout))
2244  return failure();
2245  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2246  srcType.getMemorySpace());
2247 }
2248 
2249 FailureOr<SmallVector<OpFoldResult>>
2250 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2251  MemRefType expandedType,
2252  ArrayRef<ReassociationIndices> reassociation,
2253  ArrayRef<OpFoldResult> inputShape) {
2254  std::optional<SmallVector<OpFoldResult>> outputShape =
2255  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2256  inputShape);
2257  if (!outputShape)
2258  return failure();
2259  return *outputShape;
2260 }
2261 
2262 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2263  Type resultType, Value src,
2264  ArrayRef<ReassociationIndices> reassociation,
2265  ArrayRef<OpFoldResult> outputShape) {
2266  auto [staticOutputShape, dynamicOutputShape] =
2267  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2268  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2269  getReassociationIndicesAttribute(builder, reassociation),
2270  dynamicOutputShape, staticOutputShape);
2271 }
2272 
2273 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2274  Type resultType, Value src,
2275  ArrayRef<ReassociationIndices> reassociation) {
2276  SmallVector<OpFoldResult> inputShape =
2277  getMixedSizes(builder, result.location, src);
2278  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2279  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2280  builder, result.location, memrefResultTy, reassociation, inputShape);
2281  // Failure of this assertion usually indicates presence of multiple
2282  // dynamic dimensions in the same reassociation group.
2283  assert(succeeded(outputShape) && "unable to infer output shape");
2284  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2285 }
2286 
2287 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2288  ArrayRef<int64_t> resultShape, Value src,
2289  ArrayRef<ReassociationIndices> reassociation) {
2290  // Only ranked memref source values are supported.
2291  auto srcType = llvm::cast<MemRefType>(src.getType());
2292  FailureOr<MemRefType> resultType =
2293  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2294  // Failure of this assertion usually indicates a problem with the source
2295  // type, e.g., could not get strides/offset.
2296  assert(succeeded(resultType) && "could not compute layout");
2297  build(builder, result, *resultType, src, reassociation);
2298 }
2299 
2300 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2301  ArrayRef<int64_t> resultShape, Value src,
2302  ArrayRef<ReassociationIndices> reassociation,
2303  ArrayRef<OpFoldResult> outputShape) {
2304  // Only ranked memref source values are supported.
2305  auto srcType = llvm::cast<MemRefType>(src.getType());
2306  FailureOr<MemRefType> resultType =
2307  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2308  // Failure of this assertion usually indicates a problem with the source
2309  // type, e.g., could not get strides/offset.
2310  assert(succeeded(resultType) && "could not compute layout");
2311  build(builder, result, *resultType, src, reassociation, outputShape);
2312 }
2313 
2314 LogicalResult ExpandShapeOp::verify() {
2315  MemRefType srcType = getSrcType();
2316  MemRefType resultType = getResultType();
2317 
2318  if (srcType.getRank() > resultType.getRank()) {
2319  auto r0 = srcType.getRank();
2320  auto r1 = resultType.getRank();
2321  return emitOpError("has source rank ")
2322  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2323  << r0 << " > " << r1 << ").";
2324  }
2325 
2326  // Verify result shape.
2327  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2328  resultType.getShape(),
2329  getReassociationIndices(),
2330  /*allowMultipleDynamicDimsPerGroup=*/true)))
2331  return failure();
2332 
2333  // Compute expected result type (including layout map).
2334  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2335  srcType, resultType.getShape(), getReassociationIndices());
2336  if (failed(expectedResultType))
2337  return emitOpError("invalid source layout map");
2338 
2339  // Check actual result type.
2340  if (*expectedResultType != resultType)
2341  return emitOpError("expected expanded type to be ")
2342  << *expectedResultType << " but found " << resultType;
2343 
2344  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2345  return emitOpError("expected number of static shape bounds to be equal to "
2346  "the output rank (")
2347  << resultType.getRank() << ") but found "
2348  << getStaticOutputShape().size() << " inputs instead";
2349 
2350  if ((int64_t)getOutputShape().size() !=
2351  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2352  return emitOpError("mismatch in dynamic dims in output_shape and "
2353  "static_output_shape: static_output_shape has ")
2354  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2355  << " dynamic dims while output_shape has " << getOutputShape().size()
2356  << " values";
2357 
2358  // Verify if provided output shapes are in agreement with output type.
2359  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2360  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2361  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2362  if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
2363  return emitOpError("invalid output shape provided at pos ") << pos;
2364  }
2365  }
2366 
2367  return success();
2368 }
2369 
2370 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2371  MLIRContext *context) {
2372  results.add<
2373  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2374  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2375 }
2376 
2377 /// Compute the layout map after collapsing a given source MemRef type with the
2378 /// specified reassociation indices.
2379 ///
2380 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2381 /// not possible to check this by inspecting a MemRefType in the general case.
2382 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2383 /// be valid (and thus accepted by this function) unless `strict = true`.
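/// For instance (an illustrative worked example): collapsing
/// memref<2x3x4xf32, strided<[12, 4, 1]>> with reassociation [[0, 1], [2]]
/// is provably contiguous in group [0, 1] because
/// stride(0) == stride(1) * size(1), i.e. 12 == 4 * 3, and the computed
/// result strides are [4, 1]. With source strides [24, 4, 1] the same group
/// is provably non-contiguous (24 != 4 * 3) and the computation fails.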
2384 static FailureOr<StridedLayoutAttr>
2385 computeCollapsedLayoutMap(MemRefType srcType,
2386  ArrayRef<ReassociationIndices> reassociation,
2387  bool strict = false) {
2388  int64_t srcOffset;
2389  SmallVector<int64_t> srcStrides;
2390  auto srcShape = srcType.getShape();
2391  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2392  return failure();
2393 
2394  // The result stride of a reassociation group is the stride of the last entry
2395  // of the reassociation. (TODO: Should be the minimum stride in the
2396  // reassociation because strides are not necessarily sorted. E.g., when using
2397  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2398  // strides are meaningless and could have any arbitrary value.
2399  SmallVector<int64_t> resultStrides;
2400  resultStrides.reserve(reassociation.size());
2401  for (const ReassociationIndices &reassoc : reassociation) {
2402  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2403  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2404  ref = ref.drop_back();
2405  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2406  resultStrides.push_back(srcStrides[ref.back()]);
2407  } else {
2408  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2409  // the corresponding stride may have to be skipped. (See above comment.)
2410  // Therefore, the result stride cannot be statically determined and must
2411  // be dynamic.
2412  resultStrides.push_back(ShapedType::kDynamic);
2413  }
2414  }
2415 
2416  // Validate that each reassociation group is contiguous.
2417  unsigned resultStrideIndex = resultStrides.size() - 1;
2418  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2419  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2420  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2421  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2422  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2423 
2424  // Both source and result stride must have the same static value. In that
2425  // case, we can be sure, that the dimensions are collapsible (because they
2426  // are contiguous).
2427  // If `strict = false` (default during op verification), we accept cases
2428  // where one or both strides are dynamic. This is best effort: We reject
2429  // ops where obviously non-contiguous dims are collapsed, but accept ops
2430  // where we cannot be sure statically. Such ops may fail at runtime. See
2431  // the op documentation for details.
2432  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2433  if (strict && (stride.saturated || srcStride.saturated))
2434  return failure();
2435 
2436  // Dimensions of size 1 should be skipped, because their strides are
2437  // meaningless and could have any arbitrary value.
2438  if (srcShape[idx - 1] == 1)
2439  continue;
2440 
2441  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2442  return failure();
2443  }
2444  }
2445  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2446 }
2447 
2448 bool CollapseShapeOp::isGuaranteedCollapsible(
2449  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2450  // MemRefs with identity layout are always collapsible.
2451  if (srcType.getLayout().isIdentity())
2452  return true;
2453 
2454  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2455  /*strict=*/true));
2456 }
2457 
2458 MemRefType CollapseShapeOp::computeCollapsedType(
2459  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2460  SmallVector<int64_t> resultShape;
2461  resultShape.reserve(reassociation.size());
2462  for (const ReassociationIndices &group : reassociation) {
2463  auto groupSize = SaturatedInteger::wrap(1);
2464  for (int64_t srcDim : group)
2465  groupSize =
2466  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2467  resultShape.push_back(groupSize.asInteger());
2468  }
2469 
2470  if (srcType.getLayout().isIdentity()) {
2471  // If the source is contiguous (i.e., no layout map specified), so is the
2472  // result.
2473  MemRefLayoutAttrInterface layout;
2474  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2475  srcType.getMemorySpace());
2476  }
2477 
2478  // Source may not be fully contiguous. Compute the layout map.
2479  // Note: Dimensions that are collapsed into a single dim are assumed to be
2480  // contiguous.
2481  FailureOr<StridedLayoutAttr> computedLayout =
2482  computeCollapsedLayoutMap(srcType, reassociation);
2483  assert(succeeded(computedLayout) &&
2484  "invalid source layout map or collapsing non-contiguous dims");
2485  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2486  srcType.getMemorySpace());
2487 }
2488 
2489 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2490  ArrayRef<ReassociationIndices> reassociation,
2491  ArrayRef<NamedAttribute> attrs) {
2492  auto srcType = llvm::cast<MemRefType>(src.getType());
2493  MemRefType resultType =
2494  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2495  result.addAttribute(::mlir::getReassociationAttrName(),
2496  getReassociationIndicesAttribute(b, reassociation));
2497  build(b, result, resultType, src, attrs);
2498 }
2499 
2500 LogicalResult CollapseShapeOp::verify() {
2501  MemRefType srcType = getSrcType();
2502  MemRefType resultType = getResultType();
2503 
2504  if (srcType.getRank() < resultType.getRank()) {
2505  auto r0 = srcType.getRank();
2506  auto r1 = resultType.getRank();
2507  return emitOpError("has source rank ")
2508  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2509  << r0 << " < " << r1 << ").";
2510  }
2511 
2512  // Verify result shape.
2513  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2514  srcType.getShape(), getReassociationIndices(),
2515  /*allowMultipleDynamicDimsPerGroup=*/true)))
2516  return failure();
2517 
2518  // Compute expected result type (including layout map).
2519  MemRefType expectedResultType;
2520  if (srcType.getLayout().isIdentity()) {
2521  // If the source is contiguous (i.e., no layout map specified), so is the
2522  // result.
2523  MemRefLayoutAttrInterface layout;
2524  expectedResultType =
2525  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2526  srcType.getMemorySpace());
2527  } else {
2528  // Source may not be fully contiguous. Compute the layout map.
2529  // Note: Dimensions that are collapsed into a single dim are assumed to be
2530  // contiguous.
2531  FailureOr<StridedLayoutAttr> computedLayout =
2532  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2533  if (failed(computedLayout))
2534  return emitOpError(
2535  "invalid source layout map or collapsing non-contiguous dims");
2536  expectedResultType =
2537  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2538  *computedLayout, srcType.getMemorySpace());
2539  }
2540 
2541  if (expectedResultType != resultType)
2542  return emitOpError("expected collapsed type to be ")
2543  << expectedResultType << " but found " << resultType;
2544 
2545  return success();
2546 }
2547 
2548 class CollapseShapeOpMemRefCastFolder final
2549  : public OpRewritePattern<CollapseShapeOp> {
2550 public:
2551  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2552 
2553  LogicalResult matchAndRewrite(CollapseShapeOp op,
2554  PatternRewriter &rewriter) const override {
2555  auto cast = op.getOperand().getDefiningOp<CastOp>();
2556  if (!cast)
2557  return failure();
2558 
2559  if (!CastOp::canFoldIntoConsumerOp(cast))
2560  return failure();
2561 
2562  Type newResultType = CollapseShapeOp::computeCollapsedType(
2563  llvm::cast<MemRefType>(cast.getOperand().getType()),
2564  op.getReassociationIndices());
2565 
2566  if (newResultType == op.getResultType()) {
2567  rewriter.modifyOpInPlace(
2568  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2569  } else {
2570  Value newOp = rewriter.create<CollapseShapeOp>(
2571  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2572  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2573  }
2574  return success();
2575  }
2576 };
2577 
2578 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2579  MLIRContext *context) {
2580  results.add<
2581  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2582  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2583  memref::DimOp, MemRefType>,
2584  CollapseShapeOpMemRefCastFolder>(context);
2585 }
2586 
2587 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2588  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2589  adaptor.getOperands());
2590 }
2591 
2592 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2593  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2594  adaptor.getOperands());
2595 }
2596 
2597 //===----------------------------------------------------------------------===//
2598 // ReshapeOp
2599 //===----------------------------------------------------------------------===//
2600 
2601 void ReshapeOp::getAsmResultNames(
2602  function_ref<void(Value, StringRef)> setNameFn) {
2603  setNameFn(getResult(), "reshape");
2604 }
2605 
2606 LogicalResult ReshapeOp::verify() {
2607  Type operandType = getSource().getType();
2608  Type resultType = getResult().getType();
2609 
2610  Type operandElementType =
2611  llvm::cast<ShapedType>(operandType).getElementType();
2612  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2613  if (operandElementType != resultElementType)
2614  return emitOpError("element types of source and destination memref "
2615  "types should be the same");
2616 
2617  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2618  if (!operandMemRefType.getLayout().isIdentity())
2619  return emitOpError("source memref type should have identity affine map");
2620 
2621  int64_t shapeSize =
2622  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2623  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2624  if (resultMemRefType) {
2625  if (!resultMemRefType.getLayout().isIdentity())
2626  return emitOpError("result memref type should have identity affine map");
2627  if (shapeSize == ShapedType::kDynamic)
2628  return emitOpError("cannot use shape operand with dynamic length to "
2629  "reshape to statically-ranked memref type");
2630  if (shapeSize != resultMemRefType.getRank())
2631  return emitOpError(
2632  "length of shape operand differs from the result's memref rank");
2633  }
2634  return success();
2635 }
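
// For reference, an illustrative use of memref.reshape that satisfies the
// checks above (sketch; operand names are assumptions, not from this file):
//
// ```
// // Rank-1 shape operand of static length 1, so the result rank must be 1.
// %dst = memref.reshape %src(%shape)
//     : (memref<4x5xf32>, memref<1xi32>) -> memref<?xf32>
// ```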
2636 
2637 //===----------------------------------------------------------------------===//
2638 // StoreOp
2639 //===----------------------------------------------------------------------===//
2640 
2641 LogicalResult StoreOp::verify() {
2642  if (getNumOperands() != 2 + getMemRefType().getRank())
2643  return emitOpError("store index operand count not equal to memref rank");
2644 
2645  return success();
2646 }
2647 
2648 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2649  SmallVectorImpl<OpFoldResult> &results) {
2650  /// store(memrefcast) -> store
2651  return foldMemRefCast(*this, getValueToStore());
2652 }
2653 
2654 //===----------------------------------------------------------------------===//
2655 // SubViewOp
2656 //===----------------------------------------------------------------------===//
2657 
2658 void SubViewOp::getAsmResultNames(
2659  function_ref<void(Value, StringRef)> setNameFn) {
2660  setNameFn(getResult(), "subview");
2661 }
2662 
2663 /// A subview result type can be fully inferred from the source type and the
2664 /// static representation of offsets, sizes and strides. Special sentinels
2665 /// encode the dynamic case.
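/// For example (an illustrative worked computation): for a source
/// memref<8x16xf32> (canonical strides [16, 1], offset 0), static offsets
/// [2, 4], sizes [4, 4] and strides [1, 2] give
///   offset = 0 + 2 * 16 + 4 * 1 = 36 and strides = [16 * 1, 1 * 2] = [16, 2],
/// i.e. the inferred type memref<4x4xf32, strided<[16, 2], offset: 36>>.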
2666 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2667  ArrayRef<int64_t> staticOffsets,
2668  ArrayRef<int64_t> staticSizes,
2669  ArrayRef<int64_t> staticStrides) {
2670  unsigned rank = sourceMemRefType.getRank();
2671  (void)rank;
2672  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2673  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2674  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2675 
2676  // Extract source offset and strides.
2677  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2678 
2679  // Compute target offset whose value is:
2680  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2681  int64_t targetOffset = sourceOffset;
2682  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2683  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2684  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2685  SaturatedInteger::wrap(staticOffset) *
2686  SaturatedInteger::wrap(sourceStride))
2687  .asInteger();
2688  }
2689 
2690  // Compute target stride whose value is:
2691  // `sourceStrides_i * staticStrides_i`.
2692  SmallVector<int64_t, 4> targetStrides;
2693  targetStrides.reserve(staticOffsets.size());
2694  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2695  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2696  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2697  SaturatedInteger::wrap(staticStride))
2698  .asInteger());
2699  }
2700 
2701  // The type is now known.
2702  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2703  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2704  targetOffset, targetStrides),
2705  sourceMemRefType.getMemorySpace());
2706 }
2707 
2708 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2709  ArrayRef<OpFoldResult> offsets,
2710  ArrayRef<OpFoldResult> sizes,
2711  ArrayRef<OpFoldResult> strides) {
2712  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2713  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2714  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2715  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2716  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2717  if (!hasValidSizesOffsets(staticOffsets))
2718  return {};
2719  if (!hasValidSizesOffsets(staticSizes))
2720  return {};
2721  if (!hasValidStrides(staticStrides))
2722  return {};
2723  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2724  staticSizes, staticStrides);
2725 }
2726 
2727 MemRefType SubViewOp::inferRankReducedResultType(
2728  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2729  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2730  ArrayRef<int64_t> strides) {
2731  MemRefType inferredType =
2732  inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2733  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2734  "expected ");
2735  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2736  return inferredType;
2737 
2738  // Compute which dimensions are dropped.
2739  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2740  computeRankReductionMask(inferredType.getShape(), resultShape);
2741  assert(dimsToProject.has_value() && "invalid rank reduction");
2742 
2743  // Compute the layout and result type.
2744  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2745  SmallVector<int64_t> rankReducedStrides;
2746  rankReducedStrides.reserve(resultShape.size());
2747  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2748  if (!dimsToProject->contains(idx))
2749  rankReducedStrides.push_back(value);
2750  }
2751  return MemRefType::get(resultShape, inferredType.getElementType(),
2752  StridedLayoutAttr::get(inferredLayout.getContext(),
2753  inferredLayout.getOffset(),
2754  rankReducedStrides),
2755  inferredType.getMemorySpace());
2756 }
2757 
2758 MemRefType SubViewOp::inferRankReducedResultType(
2759  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2760  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2761  ArrayRef<OpFoldResult> strides) {
2762  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2763  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2764  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2765  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2766  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2767  return SubViewOp::inferRankReducedResultType(
2768  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2769  staticStrides);
2770 }
2771 
2772 // Build a SubViewOp with mixed static and dynamic entries and custom result
2773 // type. If the type passed is nullptr, it is inferred.
2774 void SubViewOp::build(OpBuilder &b, OperationState &result,
2775  MemRefType resultType, Value source,
2776  ArrayRef<OpFoldResult> offsets,
2777  ArrayRef<OpFoldResult> sizes,
2778  ArrayRef<OpFoldResult> strides,
2779  ArrayRef<NamedAttribute> attrs) {
2780  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2781  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2782  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2783  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2784  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2785  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2786  // Structuring implementation this way avoids duplication between builders.
2787  if (!resultType) {
2788  resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2789  staticSizes, staticStrides);
2790  }
2791  result.addAttributes(attrs);
2792  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2793  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2794  b.getDenseI64ArrayAttr(staticSizes),
2795  b.getDenseI64ArrayAttr(staticStrides));
2796 }
2797 
2798 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2799 // type.
2800 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2801  ArrayRef<OpFoldResult> offsets,
2802  ArrayRef<OpFoldResult> sizes,
2803  ArrayRef<OpFoldResult> strides,
2804  ArrayRef<NamedAttribute> attrs) {
2805  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2806 }
2807 
2808 // Build a SubViewOp with static entries and inferred result type.
2809 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2810  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2811  ArrayRef<int64_t> strides,
2812  ArrayRef<NamedAttribute> attrs) {
2813  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2814  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2815  return b.getI64IntegerAttr(v);
2816  }));
2817  SmallVector<OpFoldResult> sizeValues =
2818  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2819  return b.getI64IntegerAttr(v);
2820  }));
2821  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2822  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2823  return b.getI64IntegerAttr(v);
2824  }));
2825  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2826 }
2827 
2828 // Build a SubViewOp with dynamic entries and custom result type. If the
2829 // type passed is nullptr, it is inferred.
2830 void SubViewOp::build(OpBuilder &b, OperationState &result,
2831  MemRefType resultType, Value source,
2832  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2833  ArrayRef<int64_t> strides,
2834  ArrayRef<NamedAttribute> attrs) {
2835  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2836  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2837  return b.getI64IntegerAttr(v);
2838  }));
2839  SmallVector<OpFoldResult> sizeValues =
2840  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2841  return b.getI64IntegerAttr(v);
2842  }));
2843  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2844  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2845  return b.getI64IntegerAttr(v);
2846  }));
2847  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2848  attrs);
2849 }
2850 
2851 // Build a SubViewOp with dynamic entries and custom result type. If the type
2852 // passed is nullptr, it is inferred.
2853 void SubViewOp::build(OpBuilder &b, OperationState &result,
2854  MemRefType resultType, Value source, ValueRange offsets,
2855  ValueRange sizes, ValueRange strides,
2856  ArrayRef<NamedAttribute> attrs) {
2857  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2858  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2859  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2860  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2861  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2862  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2863  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2864 }
2865 
2866 // Build a SubViewOp with dynamic entries and inferred result type.
2867 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2868  ValueRange offsets, ValueRange sizes, ValueRange strides,
2869  ArrayRef<NamedAttribute> attrs) {
2870  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2871 }
2872 
2873 /// For ViewLikeOpInterface.
2874 Value SubViewOp::getViewSource() { return getSource(); }
2875 
2876 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2877 /// static value).
2878 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2879  int64_t t1Offset, t2Offset;
2880  SmallVector<int64_t> t1Strides, t2Strides;
2881  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2882  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2883  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2884 }
2885 
2886 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2887 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2888 /// marked as dropped in `droppedDims`.
2889 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2890  const llvm::SmallBitVector &droppedDims) {
2891  assert(size_t(t1.getRank()) == droppedDims.size() &&
2892  "incorrect number of bits");
2893  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2894  "incorrect number of dropped dims");
2895  int64_t t1Offset, t2Offset;
2896  SmallVector<int64_t> t1Strides, t2Strides;
2897  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2898  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2899  if (failed(res1) || failed(res2))
2900  return false;
2901  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2902  if (droppedDims[i])
2903  continue;
2904  if (t1Strides[i] != t2Strides[j])
2905  return false;
2906  ++j;
2907  }
2908  return true;
2909 }
2910 
2911 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2912  Operation *op, Type expectedType) {
2913  auto memrefType = llvm::cast<ShapedType>(expectedType);
2914  switch (result) {
2915  case SliceVerificationResult::Success:
2916  return success();
2917  case SliceVerificationResult::RankTooLarge:
2918  return op->emitError("expected result rank to be smaller or equal to ")
2919  << "the source rank. ";
2920  case SliceVerificationResult::SizeMismatch:
2921  return op->emitError("expected result type to be ")
2922  << expectedType
2923  << " or a rank-reduced version. (mismatch of result sizes) ";
2924  case SliceVerificationResult::ElemTypeMismatch:
2925  return op->emitError("expected result element type to be ")
2926  << memrefType.getElementType();
2927  case SliceVerificationResult::MemSpaceMismatch:
2928  return op->emitError("expected result and source memory spaces to match.");
2929  case SliceVerificationResult::LayoutMismatch:
2930  return op->emitError("expected result type to be ")
2931  << expectedType
2932  << " or a rank-reduced version. (mismatch of result layout) ";
2933  }
2934  llvm_unreachable("unexpected subview verification result");
2935 }
2936 
2937 /// Verifier for SubViewOp.
2938 LogicalResult SubViewOp::verify() {
2939  MemRefType baseType = getSourceType();
2940  MemRefType subViewType = getType();
2941  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
2942  ArrayRef<int64_t> staticSizes = getStaticSizes();
2943  ArrayRef<int64_t> staticStrides = getStaticStrides();
2944 
2945  // The base memref and the view memref should be in the same memory space.
2946  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2947  return emitError("different memory spaces specified for base memref "
2948  "type ")
2949  << baseType << " and subview memref type " << subViewType;
2950 
2951  // Verify that the base memref type has a strided layout map.
2952  if (!baseType.isStrided())
2953  return emitError("base type ") << baseType << " is not strided";
2954 
2955  // Compute the expected result type, assuming that there are no rank
2956  // reductions.
2957  MemRefType expectedType = SubViewOp::inferResultType(
2958  baseType, staticOffsets, staticSizes, staticStrides);
2959 
2960  // Verify all properties of a shaped type: rank, element type and dimension
2961  // sizes. This takes into account potential rank reductions.
2962  auto shapedTypeVerification = isRankReducedType(
2963  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2964  if (shapedTypeVerification != SliceVerificationResult::Success)
2965  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2966 
2967  // Make sure that the memory space did not change.
2968  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
2969  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
2970  *this, expectedType);
2971 
2972  // Verify the offset of the layout map.
2973  if (!haveCompatibleOffsets(expectedType, subViewType))
2974  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2975  *this, expectedType);
2976 
2977  // All that is left to verify now are the strides. First, compute
2978  // the unused dimensions due to rank reductions. We have to look at sizes and
2979  // strides to decide which dimensions were dropped. This function also
2980  // partially verifies strides in case of rank reductions.
2981  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
2982  getMixedSizes());
2983  if (failed(unusedDims))
2984  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2985  *this, expectedType);
2986 
2987  // Strides must match.
2988  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
2989  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2990  *this, expectedType);
2991 
2992  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2993  // to the base memref.
2994  SliceBoundsVerificationResult boundsResult =
2995  verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
2996  staticStrides, /*generateErrorMessage=*/true);
2997  if (!boundsResult.isValid)
2998  return getOperation()->emitError(boundsResult.errorMessage);
2999 
3000  return success();
3001 }
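// Illustrative example of a rank-reducing subview that passes the checks
// above (a sketch, not from this file):
//
//   %0 = memref.subview %src[0, 0, 0] [1, 16, 4] [1, 1, 1]
//          : memref<8x16x4xf32> to memref<16x4xf32, strided<[4, 1]>>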
3002 
3003 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3004  return os << "range " << range.offset << ":" << range.size << ":"
3005  << range.stride;
3006 }
3007 
3008 /// Return the list of Range (i.e. offset, size, stride). Each Range
3009 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3010 /// with `b` at location `loc`.
3011 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3012  OpBuilder &b, Location loc) {
3013  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3014  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3015  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3016  SmallVector<Range, 8> res;
3017  unsigned rank = ranks[0];
3018  res.reserve(rank);
3019  for (unsigned idx = 0; idx < rank; ++idx) {
3020  Value offset =
3021  op.isDynamicOffset(idx)
3022  ? op.getDynamicOffset(idx)
3023  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
3024  Value size =
3025  op.isDynamicSize(idx)
3026  ? op.getDynamicSize(idx)
3027  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
3028  Value stride =
3029  op.isDynamicStride(idx)
3030  ? op.getDynamicStride(idx)
3031  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
3032  res.emplace_back(Range{offset, size, stride});
3033  }
3034  return res;
3035 }
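// Illustrative usage (a sketch, not part of this file): materializing the
// ranges of an existing subview inside a rewrite; `subViewOp` and `rewriter`
// are assumed to be in scope.
//
//   auto iface = cast<OffsetSizeAndStrideOpInterface>(subViewOp.getOperation());
//   SmallVector<Range, 8> ranges =
//       getOrCreateRanges(iface, rewriter, subViewOp.getLoc());
//   // Each Range holds offset/size/stride, either as the original dynamic
//   // value or as a freshly created arith.constant of index type.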
3036 
3037 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3038 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3039 /// the rank of the inferred result type if `currentResultType` is lower rank
3040 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3041 /// together with the result type. In this case, it is important to compute
3042 /// the dropped dimensions using `currentSourceType` whose strides align with
3043 /// `currentResultType`.
3044 static MemRefType getCanonicalSubViewResultType(
3045  MemRefType currentResultType, MemRefType currentSourceType,
3046  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3047  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3048  MemRefType nonRankReducedType = SubViewOp::inferResultType(
3049  sourceType, mixedOffsets, mixedSizes, mixedStrides);
3050  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3051  currentSourceType, currentResultType, mixedSizes);
3052  if (failed(unusedDims))
3053  return nullptr;
3054 
3055  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3056  SmallVector<int64_t> shape, strides;
3057  unsigned numDimsAfterReduction =
3058  nonRankReducedType.getRank() - unusedDims->count();
3059  shape.reserve(numDimsAfterReduction);
3060  strides.reserve(numDimsAfterReduction);
3061  for (const auto &[idx, size, stride] :
3062  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3063  nonRankReducedType.getShape(), layout.getStrides())) {
3064  if (unusedDims->test(idx))
3065  continue;
3066  shape.push_back(size);
3067  strides.push_back(stride);
3068  }
3069 
3070  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3071  StridedLayoutAttr::get(sourceType.getContext(),
3072  layout.getOffset(), strides),
3073  nonRankReducedType.getMemorySpace());
3074 }
3075 
3076 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3077  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3078  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3079  unsigned rank = memrefType.getRank();
3080  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3081  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3082  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3083  MemRefType targetType = SubViewOp::inferRankReducedResultType(
3084  targetShape, memrefType, offsets, sizes, strides);
3085  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3086  sizes, strides);
3087 }
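// For illustration (a sketch, not from this file), with targetShape = {16}
// and a source of type memref<1x16xf32>, the helper above creates roughly:
//
//   %0 = memref.subview %src[0, 0] [1, 16] [1, 1]
//          : memref<1x16xf32> to memref<16xf32, strided<[1]>>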
3088 
3089 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3090  Value value,
3091  ArrayRef<int64_t> desiredShape) {
3092  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3093  assert(sourceMemrefType && "not a ranked memref type");
3094  auto sourceShape = sourceMemrefType.getShape();
3095  if (sourceShape.equals(desiredShape))
3096  return value;
3097  auto maybeRankReductionMask =
3098  mlir::computeRankReductionMask(sourceShape, desiredShape);
3099  if (!maybeRankReductionMask)
3100  return failure();
3101  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3102 }
3103 
3104 /// Helper method to check if a `subview` operation is trivially a no-op. This
3105 /// is the case if all offsets are zero, all strides are one, and the source
3106 /// shape matches the size of the subview. In such cases, the subview can
3107 /// be folded into its source.
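///
/// Illustrative example (not from this file): the following subview is
/// trivially a no-op because it covers the whole source with unit strides.
/// ```
/// %0 = memref.subview %src[0, 0] [4, 4] [1, 1]
///        : memref<4x4xf32> to memref<4x4xf32>
/// ```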
3108 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3109  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3110  return false;
3111 
3112  auto mixedOffsets = subViewOp.getMixedOffsets();
3113  auto mixedSizes = subViewOp.getMixedSizes();
3114  auto mixedStrides = subViewOp.getMixedStrides();
3115 
3116  // Check offsets are zero.
3117  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3118  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3119  return !intValue || intValue.value() != 0;
3120  }))
3121  return false;
3122 
3123  // Check strides are one.
3124  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3125  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3126  return !intValue || intValue.value() != 1;
3127  }))
3128  return false;
3129 
3130  // Check that all size values are static and match the (static) source shape.
3131  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3132  for (const auto &size : llvm::enumerate(mixedSizes)) {
3133  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3134  if (!intValue || *intValue != sourceShape[size.index()])
3135  return false;
3136  }
3137  // All conditions met. The `SubViewOp` is foldable as a no-op.
3138  return true;
3139 }
3140 
3141 namespace {
3142 /// Pattern to rewrite a subview op with MemRefCast arguments.
3143 /// This essentially pushes memref.cast past its consuming subview when
3144 /// `canFoldIntoConsumerOp` is true.
3145 ///
3146 /// Example:
3147 /// ```
3148 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3149 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3150 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3151 /// ```
3152 /// is rewritten into:
3153 /// ```
3154 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
3155 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3156 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3157 /// ```
3158 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3159 public:
3160  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3161 
3162  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3163  PatternRewriter &rewriter) const override {
3164  // If any operand is a constant, just return to let SubViewOpConstantFolder
3165  // kick in.
3166  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3167  return matchPattern(operand, matchConstantIndex());
3168  }))
3169  return failure();
3170 
3171  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3172  if (!castOp)
3173  return failure();
3174 
3175  if (!CastOp::canFoldIntoConsumerOp(castOp))
3176  return failure();
3177 
3178  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3179  // the MemRefCastOp source operand type to infer the result type and the
3180  // current SubViewOp source operand type to compute the dropped dimensions
3181  // if the operation is rank-reducing.
3182  auto resultType = getCanonicalSubViewResultType(
3183  subViewOp.getType(), subViewOp.getSourceType(),
3184  llvm::cast<MemRefType>(castOp.getSource().getType()),
3185  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3186  subViewOp.getMixedStrides());
3187  if (!resultType)
3188  return failure();
3189 
3190  Value newSubView = rewriter.create<SubViewOp>(
3191  subViewOp.getLoc(), resultType, castOp.getSource(),
3192  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3193  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3194  subViewOp.getStaticStrides());
3195  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3196  newSubView);
3197  return success();
3198  }
3199 };
3200 
3201 /// Canonicalize subview ops that are no-ops; if the source and result types
3202 /// differ only in layout, the subview is replaced with a memref.cast.
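///
/// Illustrative example (not from this file):
/// ```
/// %0 = memref.subview %arg[0, 0] [8, 8] [1, 1]
///        : memref<8x8xf32> to memref<8x8xf32, strided<[8, 1]>>
/// ```
/// is rewritten into:
/// ```
/// %0 = memref.cast %arg : memref<8x8xf32> to memref<8x8xf32, strided<[8, 1]>>
/// ```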
3203 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3204 public:
3205  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3206 
3207  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3208  PatternRewriter &rewriter) const override {
3209  if (!isTrivialSubViewOp(subViewOp))
3210  return failure();
3211  if (subViewOp.getSourceType() == subViewOp.getType()) {
3212  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3213  return success();
3214  }
3215  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3216  subViewOp.getSource());
3217  return success();
3218  }
3219 };
3220 } // namespace
3221 
3222 /// Return the canonical type of the result of a subview.
3223 struct SubViewReturnTypeCanonicalizer {
3224  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3225  ArrayRef<OpFoldResult> mixedSizes,
3226  ArrayRef<OpFoldResult> mixedStrides) {
3227  // Infer a memref type without taking into account any rank reductions.
3228  MemRefType resTy = SubViewOp::inferResultType(
3229  op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3230  if (!resTy)
3231  return {};
3232  MemRefType nonReducedType = resTy;
3233 
3234  // Directly return the non-rank reduced type if there are no dropped dims.
3235  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3236  if (droppedDims.none())
3237  return nonReducedType;
3238 
3239  // Take the strides and offset from the non-rank reduced type.
3240  auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3241 
3242  // Drop dims from shape and strides.
3243  SmallVector<int64_t> targetShape;
3244  SmallVector<int64_t> targetStrides;
3245  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3246  if (droppedDims.test(i))
3247  continue;
3248  targetStrides.push_back(nonReducedStrides[i]);
3249  targetShape.push_back(nonReducedType.getDimSize(i));
3250  }
3251 
3252  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3253  StridedLayoutAttr::get(nonReducedType.getContext(),
3254  offset, targetStrides),
3255  nonReducedType.getMemorySpace());
3256  }
3257 };
3258 
3259 /// A canonicalizer wrapper to replace SubViewOps.
3260 struct SubViewCanonicalizer {
3261  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3262  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3263  }
3264 };
3265 
3266 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3267  MLIRContext *context) {
3268  results
3269  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3270  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3271  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3272 }
3273 
3274 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3275  MemRefType sourceMemrefType = getSource().getType();
3276  MemRefType resultMemrefType = getResult().getType();
3277  auto resultLayout =
3278  dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3279 
3280  if (resultMemrefType == sourceMemrefType &&
3281  resultMemrefType.hasStaticShape() &&
3282  (!resultLayout || resultLayout.hasStaticLayout())) {
3283  return getViewSource();
3284  }
3285 
3286  // Fold subview(subview(x)) when the second subview is a no-op: its offsets
3287  // are all zero, its strides are all one, its sizes equal the first
3288  // subview's sizes, and the result type matches the source type.
3289  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3290  auto srcSizes = srcSubview.getMixedSizes();
3291  auto sizes = getMixedSizes();
3292  auto offsets = getMixedOffsets();
3293  bool allOffsetsZero = llvm::all_of(
3294  offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3295  auto strides = getMixedStrides();
3296  bool allStridesOne = llvm::all_of(
3297  strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3298  bool allSizesSame = llvm::equal(sizes, srcSizes);
3299  if (allOffsetsZero && allStridesOne && allSizesSame &&
3300  resultMemrefType == sourceMemrefType)
3301  return getViewSource();
3302  }
3303 
3304  return {};
3305 }
3306 
3307 //===----------------------------------------------------------------------===//
3308 // TransposeOp
3309 //===----------------------------------------------------------------------===//
3310 
3311 void TransposeOp::getAsmResultNames(
3312  function_ref<void(Value, StringRef)> setNameFn) {
3313  setNameFn(getResult(), "transpose");
3314 }
3315 
3316 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
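///
/// For example (illustrative), applying (d0, d1) -> (d1, d0) to
/// memref<3x4xf32> (strides [4, 1]) yields memref<4x3xf32, strided<[1, 4]>>.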
3317 static MemRefType inferTransposeResultType(MemRefType memRefType,
3318  AffineMap permutationMap) {
3319  auto originalSizes = memRefType.getShape();
3320  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3321  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3322 
3323  // Compute permuted sizes and strides.
3324  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3325  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3326 
3327  return MemRefType::Builder(memRefType)
3328  .setShape(sizes)
3329  .setLayout(
3330  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3331 }
3332 
3333 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3334  AffineMapAttr permutation,
3335  ArrayRef<NamedAttribute> attrs) {
3336  auto permutationMap = permutation.getValue();
3337  assert(permutationMap);
3338 
3339  auto memRefType = llvm::cast<MemRefType>(in.getType());
3340  // Compute result type.
3341  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3342 
3343  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3344  build(b, result, resultType, in, attrs);
3345 }
3346 
3347 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3348 void TransposeOp::print(OpAsmPrinter &p) {
3349  p << " " << getIn() << " " << getPermutation();
3350  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3351  p << " : " << getIn().getType() << " to " << getType();
3352 }
3353 
3354 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3355  OpAsmParser::UnresolvedOperand in;
3356  AffineMap permutation;
3357  MemRefType srcType, dstType;
3358  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3359  parser.parseOptionalAttrDict(result.attributes) ||
3360  parser.parseColonType(srcType) ||
3361  parser.resolveOperand(in, srcType, result.operands) ||
3362  parser.parseKeywordType("to", dstType) ||
3363  parser.addTypeToList(dstType, result.types))
3364  return failure();
3365 
3366  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3367  AffineMapAttr::get(permutation));
3368  return success();
3369 }
3370 
3371 LogicalResult TransposeOp::verify() {
3372  if (!getPermutation().isPermutation())
3373  return emitOpError("expected a permutation map");
3374  if (getPermutation().getNumDims() != getIn().getType().getRank())
3375  return emitOpError("expected a permutation map of same rank as the input");
3376 
3377  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3378  auto resultType = llvm::cast<MemRefType>(getType());
3379  auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3380  .canonicalizeStridedLayout();
3381 
3382  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3383  return emitOpError("result type ")
3384  << resultType
3385  << " is not equivalent to the canonical transposed input type "
3386  << canonicalResultType;
3387  return success();
3388 }
3389 
3390 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3391  // First check for an identity permutation; we can fold it away if the
3392  // input and result types are already identical.
3393  if (getPermutation().isIdentity() && getType() == getIn().getType())
3394  return getIn();
3395  // Fold two consecutive memref.transpose Ops into one by composing their
3396  // permutation maps.
3397  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3398  AffineMap composedPermutation =
3399  getPermutation().compose(otherTransposeOp.getPermutation());
3400  getInMutable().assign(otherTransposeOp.getIn());
3401  setPermutation(composedPermutation);
3402  return getResult();
3403  }
3404  return {};
3405 }
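// Illustrative example of the consecutive-transpose fold above (a sketch,
// not from this file):
//
//   %0 = memref.transpose %m (d0, d1) -> (d1, d0)
//          : memref<3x4xf32> to memref<4x3xf32, strided<[1, 4]>>
//   %1 = memref.transpose %0 (d0, d1) -> (d1, d0)
//          : memref<4x3xf32, strided<[1, 4]>> to memref<3x4xf32, strided<[4, 1]>>
//
// %1 is folded in place to a single transpose of %m with the composed
// (here: identity) permutation.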
3406 
3407 //===----------------------------------------------------------------------===//
3408 // ViewOp
3409 //===----------------------------------------------------------------------===//
3410 
3411 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3412  setNameFn(getResult(), "view");
3413 }
3414 
3415 LogicalResult ViewOp::verify() {
3416  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3417  auto viewType = getType();
3418 
3419  // The base memref should have identity layout map (or none).
3420  if (!baseType.getLayout().isIdentity())
3421  return emitError("unsupported map for base memref type ") << baseType;
3422 
3423  // The result memref should have identity layout map (or none).
3424  if (!viewType.getLayout().isIdentity())
3425  return emitError("unsupported map for result memref type ") << viewType;
3426 
3427  // The base memref and the view memref should be in the same memory space.
3428  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3429  return emitError("different memory spaces specified for base memref "
3430  "type ")
3431  << baseType << " and view memref type " << viewType;
3432 
3433  // Verify that we have the correct number of sizes for the result type.
3434  unsigned numDynamicDims = viewType.getNumDynamicDims();
3435  if (getSizes().size() != numDynamicDims)
3436  return emitError("incorrect number of size operands for type ") << viewType;
3437 
3438  return success();
3439 }
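// Illustrative example of a memref.view that satisfies the checks above
// (a sketch, not from this file): a 1-D i8 buffer with identity layout is
// reinterpreted at a dynamic byte offset, with one size operand per dynamic
// dimension of the result type.
//
//   %v = memref.view %buffer[%byte_shift][%n]
//          : memref<2048xi8> to memref<?x64xf32>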
3440 
3441 Value ViewOp::getViewSource() { return getSource(); }
3442 
3443 namespace {
3444 
3445 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3446  using OpRewritePattern<ViewOp>::OpRewritePattern;
3447 
3448  LogicalResult matchAndRewrite(ViewOp viewOp,
3449  PatternRewriter &rewriter) const override {
3450  // Return if none of the operands are constants.
3451  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3452  return matchPattern(operand, matchConstantIndex());
3453  }))
3454  return failure();
3455 
3456  // Get result memref type.
3457  auto memrefType = viewOp.getType();
3458 
3459  // Get the offset from the old memref view type 'memrefType'.
3460  int64_t oldOffset;
3461  SmallVector<int64_t, 4> oldStrides;
3462  if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3463  return failure();
3464  assert(oldOffset == 0 && "Expected 0 offset");
3465 
3466  SmallVector<Value, 4> newOperands;
3467 
3468  // Offset cannot be folded into result type.
3469 
3470  // Fold any dynamic dim operands which are produced by a constant.
3471  SmallVector<int64_t, 4> newShapeConstants;
3472  newShapeConstants.reserve(memrefType.getRank());
3473 
3474  unsigned dynamicDimPos = 0;
3475  unsigned rank = memrefType.getRank();
3476  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3477  int64_t dimSize = memrefType.getDimSize(dim);
3478  // If this is already a static dimension, keep it.
3479  if (!ShapedType::isDynamic(dimSize)) {
3480  newShapeConstants.push_back(dimSize);
3481  continue;
3482  }
3483  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3484  if (auto constantIndexOp =
3485  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3486  // Dynamic shape dimension will be folded.
3487  newShapeConstants.push_back(constantIndexOp.value());
3488  } else {
3489  // Dynamic shape dimension not folded; copy operand from old memref.
3490  newShapeConstants.push_back(dimSize);
3491  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3492  }
3493  dynamicDimPos++;
3494  }
3495 
3496  // Create new memref type with constant folded dims.
3497  MemRefType newMemRefType =
3498  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3499  // Nothing new, don't fold.
3500  if (newMemRefType == memrefType)
3501  return failure();
3502 
3503  // Create new ViewOp.
3504  auto newViewOp = rewriter.create<ViewOp>(
3505  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3506  viewOp.getByteShift(), newOperands);
3507  // Insert a cast so we have the same type as the old memref type.
3508  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3509  return success();
3510  }
3511 };
3512 
3513 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3514  using OpRewritePattern<ViewOp>::OpRewritePattern;
3515 
3516  LogicalResult matchAndRewrite(ViewOp viewOp,
3517  PatternRewriter &rewriter) const override {
3518  Value memrefOperand = viewOp.getOperand(0);
3519  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3520  if (!memrefCastOp)
3521  return failure();
3522  Value allocOperand = memrefCastOp.getOperand();
3523  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3524  if (!allocOp)
3525  return failure();
3526  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3527  viewOp.getByteShift(),
3528  viewOp.getSizes());
3529  return success();
3530  }
3531 };
3532 
3533 } // namespace
3534 
3535 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3536  MLIRContext *context) {
3537  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3538 }
3539 
3540 //===----------------------------------------------------------------------===//
3541 // AtomicRMWOp
3542 //===----------------------------------------------------------------------===//
3543 
3544 LogicalResult AtomicRMWOp::verify() {
3545  if (getMemRefType().getRank() != getNumOperands() - 2)
3546  return emitOpError(
3547  "expects the number of subscripts to be equal to memref rank");
3548  switch (getKind()) {
3549  case arith::AtomicRMWKind::addf:
3550  case arith::AtomicRMWKind::maximumf:
3551  case arith::AtomicRMWKind::minimumf:
3552  case arith::AtomicRMWKind::mulf:
3553  if (!llvm::isa<FloatType>(getValue().getType()))
3554  return emitOpError() << "with kind '"
3555  << arith::stringifyAtomicRMWKind(getKind())
3556  << "' expects a floating-point type";
3557  break;
3558  case arith::AtomicRMWKind::addi:
3559  case arith::AtomicRMWKind::maxs:
3560  case arith::AtomicRMWKind::maxu:
3561  case arith::AtomicRMWKind::mins:
3562  case arith::AtomicRMWKind::minu:
3563  case arith::AtomicRMWKind::muli:
3564  case arith::AtomicRMWKind::ori:
3565  case arith::AtomicRMWKind::andi:
3566  if (!llvm::isa<IntegerType>(getValue().getType()))
3567  return emitOpError() << "with kind '"
3568  << arith::stringifyAtomicRMWKind(getKind())
3569  << "' expects an integer type";
3570  break;
3571  default:
3572  break;
3573  }
3574  return success();
3575 }
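// Illustrative example of an op accepted by the verifier above (a sketch,
// not from this file): a floating-point kind applied to an f32 element,
// written roughly as:
//
//   %old = memref.atomic_rmw addf %val, %mem[%i]
//            : (f32, memref<16xf32>) -> f32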
3576 
3577 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3578  /// atomicrmw(memrefcast) -> atomicrmw
3579  if (succeeded(foldMemRefCast(*this, getValue())))
3580  return getResult();
3581  return OpFoldResult();
3582 }
3583 
3584 //===----------------------------------------------------------------------===//
3585 // TableGen'd op method definitions
3586 //===----------------------------------------------------------------------===//
3587 
3588 #define GET_OP_CLASSES
3589 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
static bool hasSideEffects(Operation *op)
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isPermutation(std::vector< PermutationTy > permutation)
Definition: IRAffine.cpp:67
static MLIRContext * getContext(OpFoldResult val)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
Definition: MemRefOps.cpp:96
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
Definition: MemRefOps.cpp:1499
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
Definition: MemRefOps.cpp:2097
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
Definition: MemRefOps.cpp:377
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
Definition: MemRefOps.cpp:3317
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
Definition: MemRefOps.cpp:358
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
Definition: MemRefOps.cpp:2878
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
Definition: MemRefOps.cpp:1337
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
Definition: MemRefOps.cpp:3044
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, Operation *op, Type expectedType)
Definition: MemRefOps.cpp:2911
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition: MemRefOps.cpp:1513
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
Definition: MemRefOps.cpp:877
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
Definition: MemRefOps.cpp:3108
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
Definition: MemRefOps.cpp:400
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
Definition: MemRefOps.cpp:2889
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map whose keys are the elements of vals and whose values are their number of occurrences.
Definition: MemRefOps.cpp:862
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
Definition: MemRefOps.cpp:2190
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
Definition: MemRefOps.cpp:2385
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
Definition: MemRefOps.cpp:129
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition: TensorOps.cpp:140
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:102
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:51
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:166
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:187
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:177
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:395
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:409
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:243
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Definition: Operation.cpp:227
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Region.h:98
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition: MemRefOps.cpp:60
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition: MemRefOps.cpp:68
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Definition: MemRefOps.cpp:3076
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:322
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:340
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:474
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
Definition: MemRefOps.cpp:3011
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:24
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
Move allocations into an allocation scope, if it is legal to move them (e.g.
Definition: MemRefOps.cpp:447
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:450
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
Definition: MemRefOps.cpp:407
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:410
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:2553
A canonicalizer wrapper to replace SubViewOps.
Definition: MemRefOps.cpp:3260
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Definition: MemRefOps.cpp:3261
Return the canonical type of the result of a subview.
Definition: MemRefOps.cpp:3223
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: MemRefOps.cpp:3224
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
Result for slice bounds verification;.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.