MLIR 21.0.0git
MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::memref;
29 
30 /// Materialize a single constant operation from a given attribute value with
31 /// the desired resultant type.
32 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
33  Attribute value, Type type,
34  Location loc) {
35  return arith::ConstantOp::materialize(builder, value, type, loc);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Common canonicalization pattern support logic
40 //===----------------------------------------------------------------------===//
41 
42 /// This is a common class used for patterns of the form
43 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
44 /// into the root operation directly.
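/// For example, `memref.dealloc` folds through this helper, so (illustrative
/// IR; any op whose folder calls this helper behaves the same)
///
/// ```mlir
/// %0 = memref.cast %arg : memref<8x16xf32> to memref<?x?xf32>
/// memref.dealloc %0 : memref<?x?xf32>
/// ```
///
/// may fold into
///
/// ```mlir
/// memref.dealloc %arg : memref<8x16xf32>
/// ```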
45 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
46  bool folded = false;
47  for (OpOperand &operand : op->getOpOperands()) {
48  auto cast = operand.get().getDefiningOp<CastOp>();
49  if (cast && operand.get() != inner &&
50  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
51  operand.set(cast.getOperand());
52  folded = true;
53  }
54  }
55  return success(folded);
56 }
57 
58 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
59 /// type.
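///
/// For example (illustrative mapping):
///
/// ```
/// memref<4x?xf32> -> tensor<4x?xf32>
/// memref<*xf32>   -> tensor<*xf32>
/// ```
///
/// Any non-memref type maps to `none`.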
60 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
61  if (auto memref = llvm::dyn_cast<MemRefType>(type))
62  return RankedTensorType::get(memref.getShape(), memref.getElementType());
63  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
64  return UnrankedTensorType::get(memref.getElementType());
65  return NoneType::get(type.getContext());
66 }
67 
68 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
69  int64_t dim) {
70  auto memrefType = llvm::cast<MemRefType>(value.getType());
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78  Location loc, Value value) {
79  auto memrefType = llvm::cast<MemRefType>(value.getType());
80  SmallVector<OpFoldResult> result;
81  for (int64_t i = 0; i < memrefType.getRank(); ++i)
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result;
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that sets values[i] to constValues[i] if the latter is a
91 /// static value (i.e., not ShapedType::kDynamic).
92 ///
93 /// If constValues[i] is dynamic, tries to extract a constant value from
94 /// value[i] to allow for additional folding opportunities. Also converts all
95 /// existing attributes to index attributes. (They may be i64 attributes.)
96 static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
97  ArrayRef<int64_t> constValues) {
98  assert(constValues.size() == values.size() &&
99  "incorrect number of const values");
100  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
101  Builder builder(values[i].getContext());
102  if (!ShapedType::isDynamic(cstVal)) {
103  // Constant value is known, use it directly.
104  values[i] = builder.getIndexAttr(cstVal);
105  continue;
106  }
107  if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
108  // Try to extract a constant or convert an existing to index.
109  values[i] = builder.getIndexAttr(*cst);
110  }
111  }
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // AllocOp / AllocaOp
116 //===----------------------------------------------------------------------===//
117 
118 void AllocOp::getAsmResultNames(
119  function_ref<void(Value, StringRef)> setNameFn) {
120  setNameFn(getResult(), "alloc");
121 }
122 
123 void AllocaOp::getAsmResultNames(
124  function_ref<void(Value, StringRef)> setNameFn) {
125  setNameFn(getResult(), "alloca");
126 }
127 
128 template <typename AllocLikeOp>
129 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
130  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
131  "applies to only alloc or alloca");
132  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
133  if (!memRefType)
134  return op.emitOpError("result must be a memref");
135 
136  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
137  return op.emitOpError("dimension operand count does not equal memref "
138  "dynamic dimension count");
139 
140  unsigned numSymbols = 0;
141  if (!memRefType.getLayout().isIdentity())
142  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
143  if (op.getSymbolOperands().size() != numSymbols)
144  return op.emitOpError("symbol operand count does not equal memref symbol "
145  "count: expected ")
146  << numSymbols << ", got " << op.getSymbolOperands().size();
147 
148  return success();
149 }
150 
151 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
152 
153 LogicalResult AllocaOp::verify() {
154  // An alloca op needs to have an ancestor with an allocation scope trait.
155  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
156  return emitOpError(
157  "requires an ancestor op with AutomaticAllocationScope trait");
158 
159  return verifyAllocLikeOp(*this);
160 }
161 
162 namespace {
163 /// Fold constant dimensions into an alloc like operation.
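///
/// For example (illustrative IR), an alloc with a constant dynamic size
///
/// ```mlir
/// %c8 = arith.constant 8 : index
/// %0 = memref.alloc(%c8, %n) : memref<?x?xf32>
/// ```
///
/// may be rewritten as
///
/// ```mlir
/// %1 = memref.alloc(%n) : memref<8x?xf32>
/// %0 = memref.cast %1 : memref<8x?xf32> to memref<?x?xf32>
/// ```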
164 template <typename AllocLikeOp>
165 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
166  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
167 
168  LogicalResult matchAndRewrite(AllocLikeOp alloc,
169  PatternRewriter &rewriter) const override {
170  // Check to see if any dimension operands are constants. If so, we can
171  // substitute and drop them.
172  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
173  APInt constSizeArg;
174  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
175  return false;
176  return constSizeArg.isNonNegative();
177  }))
178  return failure();
179 
180  auto memrefType = alloc.getType();
181 
182  // Ok, we have one or more constant operands. Collect the non-constant ones
183  // and keep track of the resultant memref type to build.
184  SmallVector<int64_t, 4> newShapeConstants;
185  newShapeConstants.reserve(memrefType.getRank());
186  SmallVector<Value, 4> dynamicSizes;
187 
188  unsigned dynamicDimPos = 0;
189  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
190  int64_t dimSize = memrefType.getDimSize(dim);
191  // If this is already a static dimension, keep it.
192  if (!ShapedType::isDynamic(dimSize)) {
193  newShapeConstants.push_back(dimSize);
194  continue;
195  }
196  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
197  APInt constSizeArg;
198  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
199  constSizeArg.isNonNegative()) {
200  // Dynamic shape dimension will be folded.
201  newShapeConstants.push_back(constSizeArg.getZExtValue());
202  } else {
203  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
204  newShapeConstants.push_back(ShapedType::kDynamic);
205  dynamicSizes.push_back(dynamicSize);
206  }
207  dynamicDimPos++;
208  }
209 
210  // Create new memref type (which will have fewer dynamic dimensions).
211  MemRefType newMemRefType =
212  MemRefType::Builder(memrefType).setShape(newShapeConstants);
213  assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
214 
215  // Create and insert the alloc op for the new memref.
216  auto newAlloc = rewriter.create<AllocLikeOp>(
217  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
218  alloc.getAlignmentAttr());
219  // Insert a cast so we have the same type as the old alloc.
220  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
221  return success();
222  }
223 };
224 
225 /// Fold alloc operations with no users or only store and dealloc uses.
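///
/// For example (illustrative IR), the following alloc and all of its uses are
/// erased, since the stored value is unobservable:
///
/// ```mlir
/// %0 = memref.alloc() : memref<4xf32>
/// memref.store %v, %0[%i] : memref<4xf32>
/// memref.dealloc %0 : memref<4xf32>
/// ```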
226 template <typename T>
227 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
228  using OpRewritePattern<T>::OpRewritePattern;
229 
230  LogicalResult matchAndRewrite(T alloc,
231  PatternRewriter &rewriter) const override {
232  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
233  if (auto storeOp = dyn_cast<StoreOp>(op))
234  return storeOp.getValue() == alloc;
235  return !isa<DeallocOp>(op);
236  }))
237  return failure();
238 
239  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
240  rewriter.eraseOp(user);
241 
242  rewriter.eraseOp(alloc);
243  return success();
244  }
245 };
246 } // namespace
247 
248 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
249  MLIRContext *context) {
250  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
251 }
252 
253 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
254  MLIRContext *context) {
255  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
256  context);
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // ReallocOp
261 //===----------------------------------------------------------------------===//
262 
263 LogicalResult ReallocOp::verify() {
264  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
265  MemRefType resultType = getType();
266 
267  // The source memref should have identity layout (or none).
268  if (!sourceType.getLayout().isIdentity())
269  return emitError("unsupported layout for source memref type ")
270  << sourceType;
271 
272  // The result memref should have identity layout (or none).
273  if (!resultType.getLayout().isIdentity())
274  return emitError("unsupported layout for result memref type ")
275  << resultType;
276 
277  // The source memref and the result memref should be in the same memory space.
278  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
279  return emitError("different memory spaces specified for source memref "
280  "type ")
281  << sourceType << " and result memref type " << resultType;
282 
283  // The source memref and the result memref should have the same element type.
284  if (sourceType.getElementType() != resultType.getElementType())
285  return emitError("different element types specified for source memref "
286  "type ")
287  << sourceType << " and result memref type " << resultType;
288 
289  // Verify that we have the dynamic dimension operand when it is needed.
290  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
291  return emitError("missing dimension operand for result type ")
292  << resultType;
293  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
294  return emitError("unnecessary dimension operand for result type ")
295  << resultType;
296 
297  return success();
298 }
299 
300 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
301  MLIRContext *context) {
302  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // AllocaScopeOp
307 //===----------------------------------------------------------------------===//
308 
309 void AllocaScopeOp::print(OpAsmPrinter &p) {
310  bool printBlockTerminators = false;
311 
312  p << ' ';
313  if (!getResults().empty()) {
314  p << " -> (" << getResultTypes() << ")";
315  printBlockTerminators = true;
316  }
317  p << ' ';
318  p.printRegion(getBodyRegion(),
319  /*printEntryBlockArgs=*/false,
320  /*printBlockTerminators=*/printBlockTerminators);
321  p.printOptionalAttrDict((*this)->getAttrs());
322 }
323 
324 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
325  // Create a region for the body.
326  result.regions.reserve(1);
327  Region *bodyRegion = result.addRegion();
328 
329  // Parse optional results type list.
330  if (parser.parseOptionalArrowTypeList(result.types))
331  return failure();
332 
333  // Parse the body region.
334  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
335  return failure();
336  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
337  result.location);
338 
339  // Parse the optional attribute list.
340  if (parser.parseOptionalAttrDict(result.attributes))
341  return failure();
342 
343  return success();
344 }
345 
346 void AllocaScopeOp::getSuccessorRegions(
347  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
348  if (!point.isParent()) {
349  regions.push_back(RegionSuccessor(getResults()));
350  return;
351  }
352 
353  regions.push_back(RegionSuccessor(&getBodyRegion()));
354 }
355 
356 /// Given an operation, return whether this op is guaranteed to
357 /// allocate an AutomaticAllocationScopeResource
358 static bool isGuaranteedAutomaticAllocation(Operation *op) {
359  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
360  if (!interface)
361  return false;
362  for (auto res : op->getResults()) {
363  if (auto effect =
364  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
365  if (isa<SideEffects::AutomaticAllocationScopeResource>(
366  effect->getResource()))
367  return true;
368  }
369  }
370  return false;
371 }
372 
373 /// Given an operation, return whether this op itself could
374 /// allocate an AutomaticAllocationScopeResource. Note that
375 /// this will not check whether an operation contained within
376 /// the op can allocate.
377 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
378  // This op itself doesn't create a stack allocation,
379  // the inner allocation should be handled separately.
380  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
381  return false;
382  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
383  if (!interface)
384  return true;
385  for (auto res : op->getResults()) {
386  if (auto effect =
387  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
388  if (isa<SideEffects::AutomaticAllocationScopeResource>(
389  effect->getResource()))
390  return true;
391  }
392  }
393  return false;
394 }
395 
396 /// Return whether this op is the last non-terminating op
397 /// in a region. That is to say, it is in a one-block region
398 /// and is only followed by a terminator. This prevents
399 /// extending the lifetime of allocations.
400 static bool lastNonTerminatorInRegion(Operation *op) {
401  return op->getBlock()->mightHaveTerminator() &&
402  op->getNextNode() == op->getBlock()->getTerminator() &&
403  op->getParentRegion()->hasOneBlock();
404 }
405 
406 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
407 /// or it contains no allocation.
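///
/// For example (illustrative IR; "test.op" stands for an arbitrary op that
/// does not allocate), a scope containing no allocation is simply inlined:
///
/// ```mlir
/// memref.alloca_scope {
///   "test.op"() : () -> ()
/// }
/// ```
///
/// becomes
///
/// ```mlir
/// "test.op"() : () -> ()
/// ```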
408 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
409  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
410 
411  LogicalResult matchAndRewrite(AllocaScopeOp op,
412  PatternRewriter &rewriter) const override {
413  bool hasPotentialAlloca =
414  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
415  if (alloc == op)
416  return WalkResult::advance();
417  if (isOpItselfPotentialAutomaticAllocation(alloc))
418  return WalkResult::interrupt();
419  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
420  return WalkResult::skip();
421  return WalkResult::advance();
422  }).wasInterrupted();
423 
424  // If this contains no potential allocation, it is always legal to
425  // inline. Otherwise, consider two conditions:
426  if (hasPotentialAlloca) {
427  // If the parent isn't an allocation scope, or we are not the last
428  // non-terminator op in the parent, we will extend the lifetime.
429  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
430  return failure();
431  if (!lastNonTerminatorInRegion(op))
432  return failure();
433  }
434 
435  Block *block = &op.getRegion().front();
436  Operation *terminator = block->getTerminator();
437  ValueRange results = terminator->getOperands();
438  rewriter.inlineBlockBefore(block, op);
439  rewriter.replaceOp(op, results);
440  rewriter.eraseOp(terminator);
441  return success();
442  }
443 };
444 
445 /// Move allocations into an allocation scope, if it is legal to
446 /// move them (e.g. their operands are available at the location
447 /// the op would be moved to).
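///
/// For example (illustrative IR; "test.use" is a placeholder op, and the scope
/// must sit last in the body of an op inside an enclosing
/// AutomaticAllocationScope such as a function), an alloca with no operands
/// may be hoisted out of a loop:
///
/// ```mlir
/// scf.for %i = %lb to %ub step %s {
///   memref.alloca_scope {
///     %a = memref.alloca() : memref<16xf32>
///     "test.use"(%a) : (memref<16xf32>) -> ()
///   }
/// }
/// ```
///
/// may become
///
/// ```mlir
/// %a = memref.alloca() : memref<16xf32>
/// scf.for %i = %lb to %ub step %s {
///   memref.alloca_scope {
///     "test.use"(%a) : (memref<16xf32>) -> ()
///   }
/// }
/// ```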
448 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
449  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
450 
451  LogicalResult matchAndRewrite(AllocaScopeOp op,
452  PatternRewriter &rewriter) const override {
453 
454  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
455  return failure();
456 
457  Operation *lastParentWithoutScope = op->getParentOp();
458 
459  if (!lastParentWithoutScope ||
460  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
461  return failure();
462 
463  // Only apply if this is the last non-terminator op in the block of a
464  // one-block region (lest the lifetime of the allocations be
465  // extended).
466  if (!lastNonTerminatorInRegion(op) ||
467  !lastNonTerminatorInRegion(lastParentWithoutScope))
468  return failure();
469 
470  while (!lastParentWithoutScope->getParentOp()
471  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
472  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
473  if (!lastParentWithoutScope ||
474  !lastNonTerminatorInRegion(lastParentWithoutScope))
475  return failure();
476  }
477  assert(lastParentWithoutScope->getParentOp()
478  ->hasTrait<OpTrait::AutomaticAllocationScope>());
479 
480  Region *containingRegion = nullptr;
481  for (auto &r : lastParentWithoutScope->getRegions()) {
482  if (r.isAncestor(op->getParentRegion())) {
483  assert(containingRegion == nullptr &&
484  "only one region can contain the op");
485  containingRegion = &r;
486  }
487  }
488  assert(containingRegion && "op must be contained in a region");
489 
490  SmallVector<Operation *> toHoist;
491  op->walk([&](Operation *alloc) {
492  if (!isGuaranteedAutomaticAllocation(alloc))
493  return WalkResult::skip();
494 
495  // If any operand is not defined before the location of
496  // lastParentWithoutScope (i.e. where we would hoist to), skip.
497  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
498  return containingRegion->isAncestor(v.getParentRegion());
499  }))
500  return WalkResult::skip();
501  toHoist.push_back(alloc);
502  return WalkResult::advance();
503  });
504 
505  if (toHoist.empty())
506  return failure();
507  rewriter.setInsertionPoint(lastParentWithoutScope);
508  for (auto *op : toHoist) {
509  auto *cloned = rewriter.clone(*op);
510  rewriter.replaceOp(op, cloned->getResults());
511  }
512  return success();
513  }
514 };
515 
516 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
517  MLIRContext *context) {
518  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // AssumeAlignmentOp
523 //===----------------------------------------------------------------------===//
524 
525 LogicalResult AssumeAlignmentOp::verify() {
526  if (!llvm::isPowerOf2_32(getAlignment()))
527  return emitOpError("alignment must be power of 2");
528  return success();
529 }
530 
531 void AssumeAlignmentOp::getAsmResultNames(
532  function_ref<void(Value, StringRef)> setNameFn) {
533  setNameFn(getResult(), "assume_align");
534 }
535 
536 OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
537  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
538  if (!source)
539  return {};
540  if (source.getAlignment() != getAlignment())
541  return {};
542  return getMemref();
543 }
544 
545 //===----------------------------------------------------------------------===//
546 // CastOp
547 //===----------------------------------------------------------------------===//
548 
549 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
550  setNameFn(getResult(), "cast");
551 }
552 
553 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
554 /// source memref. This is useful to fold a memref.cast into a consuming op
555 /// and implement canonicalization patterns for ops in different dialects that
556 /// may consume the results of memref.cast operations. Such foldable memref.cast
557 /// operations are typically inserted as `view` and `subview` ops are
558 /// canonicalized, to preserve the type compatibility of their uses.
559 ///
560 /// Returns true when all conditions are met:
561 /// 1. source and result are ranked memrefs with strided semantics and same
562 /// element type and rank.
563 /// 2. each of the source's size, offset or stride has more static information
564 /// than the corresponding result's size, offset or stride.
565 ///
566 /// Example 1:
567 /// ```mlir
568 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
569 /// %2 = consumer %1 ... : memref<?x?xf32> ...
570 /// ```
571 ///
572 /// may fold into:
573 ///
574 /// ```mlir
575 /// %2 = consumer %0 ... : memref<8x16xf32> ...
576 /// ```
577 ///
578 /// Example 2:
579 /// ```
580 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
581 /// to memref<?x?xf32>
582 /// consumer %1 : memref<?x?xf32> ...
583 /// ```
584 ///
585 /// may fold into:
586 ///
587 /// ```
588 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
589 /// ```
590 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
591  MemRefType sourceType =
592  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
593  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
594 
595  // Requires ranked MemRefType.
596  if (!sourceType || !resultType)
597  return false;
598 
599  // Requires same elemental type.
600  if (sourceType.getElementType() != resultType.getElementType())
601  return false;
602 
603  // Requires same rank.
604  if (sourceType.getRank() != resultType.getRank())
605  return false;
606 
607  // Only fold casts between strided memref forms.
608  int64_t sourceOffset, resultOffset;
609  SmallVector<int64_t, 4> sourceStrides, resultStrides;
610  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
611  failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
612  return false;
613 
614  // If cast is towards more static sizes along any dimension, don't fold.
615  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
616  auto ss = std::get<0>(it), st = std::get<1>(it);
617  if (ss != st)
618  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
619  return false;
620  }
621 
622  // If cast is towards more static offset along any dimension, don't fold.
623  if (sourceOffset != resultOffset)
624  if (ShapedType::isDynamic(sourceOffset) &&
625  !ShapedType::isDynamic(resultOffset))
626  return false;
627 
628  // If cast is towards more static strides along any dimension, don't fold.
629  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
630  auto ss = std::get<0>(it), st = std::get<1>(it);
631  if (ss != st)
632  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
633  return false;
634  }
635 
636  return true;
637 }
638 
639 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
640  if (inputs.size() != 1 || outputs.size() != 1)
641  return false;
642  Type a = inputs.front(), b = outputs.front();
643  auto aT = llvm::dyn_cast<MemRefType>(a);
644  auto bT = llvm::dyn_cast<MemRefType>(b);
645 
646  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
647  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
648 
649  if (aT && bT) {
650  if (aT.getElementType() != bT.getElementType())
651  return false;
652  if (aT.getLayout() != bT.getLayout()) {
653  int64_t aOffset, bOffset;
654  SmallVector<int64_t, 4> aStrides, bStrides;
655  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
656  failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
657  aStrides.size() != bStrides.size())
658  return false;
659 
660  // Strides along a dimension/offset are compatible if the value in the
661  // source memref is static and the value in the target memref is the
662  // same. They are also compatible if either one is dynamic (see
663  // description of MemRefCastOp for details).
664  auto checkCompatible = [](int64_t a, int64_t b) {
665  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
666  };
667  if (!checkCompatible(aOffset, bOffset))
668  return false;
669  for (const auto &aStride : enumerate(aStrides))
670  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
671  return false;
672  }
673  if (aT.getMemorySpace() != bT.getMemorySpace())
674  return false;
675 
676  // They must have the same rank, and any specified dimensions must match.
677  if (aT.getRank() != bT.getRank())
678  return false;
679 
680  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
681  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
682  if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
683  aDim != bDim)
684  return false;
685  }
686  return true;
687  } else {
688  if (!aT && !uaT)
689  return false;
690  if (!bT && !ubT)
691  return false;
692  // Unranked to unranked casting is unsupported
693  if (uaT && ubT)
694  return false;
695 
696  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
697  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
698  if (aEltType != bEltType)
699  return false;
700 
701  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
702  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
703  return aMemSpace == bMemSpace;
704  }
705 
706  return false;
707 }
708 
709 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
710  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // CopyOp
715 //===----------------------------------------------------------------------===//
716 
717 namespace {
718 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
719 /// and element type, the cast can be skipped. Such CastOps only cast the layout
720 /// of the type.
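///
/// For example (illustrative IR), a cast that only changes the layout is
/// bypassed:
///
/// ```mlir
/// %0 = memref.cast %arg : memref<8xf32> to memref<8xf32, strided<[1], offset: ?>>
/// memref.copy %0, %dst : memref<8xf32, strided<[1], offset: ?>> to memref<8xf32>
/// ```
///
/// may become
///
/// ```mlir
/// memref.copy %arg, %dst : memref<8xf32> to memref<8xf32>
/// ```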
721 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
722  using OpRewritePattern<CopyOp>::OpRewritePattern;
723 
724  LogicalResult matchAndRewrite(CopyOp copyOp,
725  PatternRewriter &rewriter) const override {
726  bool modified = false;
727 
728  // Check source.
729  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
730  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
731  auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
732 
733  if (fromType && toType) {
734  if (fromType.getShape() == toType.getShape() &&
735  fromType.getElementType() == toType.getElementType()) {
736  rewriter.modifyOpInPlace(copyOp, [&] {
737  copyOp.getSourceMutable().assign(castOp.getSource());
738  });
739  modified = true;
740  }
741  }
742  }
743 
744  // Check target.
745  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
746  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
747  auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
748 
749  if (fromType && toType) {
750  if (fromType.getShape() == toType.getShape() &&
751  fromType.getElementType() == toType.getElementType()) {
752  rewriter.modifyOpInPlace(copyOp, [&] {
753  copyOp.getTargetMutable().assign(castOp.getSource());
754  });
755  modified = true;
756  }
757  }
758  }
759 
760  return success(modified);
761  }
762 };
763 
764 /// Fold memref.copy(%x, %x).
765 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
766  using OpRewritePattern<CopyOp>::OpRewritePattern;
767 
768  LogicalResult matchAndRewrite(CopyOp copyOp,
769  PatternRewriter &rewriter) const override {
770  if (copyOp.getSource() != copyOp.getTarget())
771  return failure();
772 
773  rewriter.eraseOp(copyOp);
774  return success();
775  }
776 };
777 
778 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
779  using OpRewritePattern<CopyOp>::OpRewritePattern;
780 
781  static bool isEmptyMemRef(BaseMemRefType type) {
782  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
783  }
784 
785  LogicalResult matchAndRewrite(CopyOp copyOp,
786  PatternRewriter &rewriter) const override {
787  if (isEmptyMemRef(copyOp.getSource().getType()) ||
788  isEmptyMemRef(copyOp.getTarget().getType())) {
789  rewriter.eraseOp(copyOp);
790  return success();
791  }
792 
793  return failure();
794  }
795 };
796 } // namespace
797 
798 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
799  MLIRContext *context) {
800  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
801 }
802 
803 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
804  SmallVectorImpl<OpFoldResult> &results) {
805  /// copy(memrefcast) -> copy
806  bool folded = false;
807  Operation *op = *this;
808  for (OpOperand &operand : op->getOpOperands()) {
809  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
810  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
811  operand.set(castOp.getOperand());
812  folded = true;
813  }
814  }
815  return success(folded);
816 }
817 
818 //===----------------------------------------------------------------------===//
819 // DeallocOp
820 //===----------------------------------------------------------------------===//
821 
822 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
823  SmallVectorImpl<OpFoldResult> &results) {
824  /// dealloc(memrefcast) -> dealloc
825  return foldMemRefCast(*this);
826 }
827 
828 //===----------------------------------------------------------------------===//
829 // DimOp
830 //===----------------------------------------------------------------------===//
831 
832 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
833  setNameFn(getResult(), "dim");
834 }
835 
836 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
837  int64_t index) {
838  auto loc = result.location;
839  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
840  build(builder, result, source, indexValue);
841 }
842 
843 std::optional<int64_t> DimOp::getConstantIndex() {
844  return getConstantIntValue(getIndex());
845 }
846 
847 Speculation::Speculatability DimOp::getSpeculatability() {
848  auto constantIndex = getConstantIndex();
849  if (!constantIndex)
850  return Speculation::NotSpeculatable;
851 
852  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
853  if (!rankedSourceType)
854  return Speculation::NotSpeculatable;
855 
856  if (rankedSourceType.getRank() <= constantIndex)
857  return Speculation::NotSpeculatable;
858 
859  return Speculation::Speculatable;
860 }
861 
862 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
863  SetIntLatticeFn setResultRange) {
864  setResultRange(getResult(),
865  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
866 }
867 
868 /// Return a map with keys being elements in `vals` and data being the number
869 /// of occurrences of each. Use std::map, since the `vals` here are strides and the
870 /// dynamic stride value is the same as the tombstone value for
871 /// `DenseMap<int64_t>`.
872 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
873  std::map<int64_t, unsigned> numOccurences;
874  for (auto val : vals)
875  numOccurences[val]++;
876  return numOccurences;
877 }
878 
879 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
880 /// to be a subset of `originalType` with some `1` entries erased, return the
881 /// set of indices that specifies which of the entries of `originalShape` are
882 /// dropped to obtain `reducedShape`.
883 /// This accounts for cases where there are multiple unit-dims, but only a
884 /// subset of those are dropped. For MemRefTypes these can be disambiguated
885 /// using the strides. If a dimension is dropped the stride must be dropped too.
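///
/// For example (illustrative types), two unit dims share the same stride and
/// the occurrence counts resolve which one is considered dropped:
///
/// ```
/// originalType = memref<1x1x5xf32, strided<[5, 5, 1]>>  // sizes = [1, 1, 5]
/// reducedType  = memref<1x5xf32, strided<[5, 1]>>
/// => unusedDims = {0}
/// ```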
886 static FailureOr<llvm::SmallBitVector>
887 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
888  ArrayRef<OpFoldResult> sizes) {
889  llvm::SmallBitVector unusedDims(originalType.getRank());
890  if (originalType.getRank() == reducedType.getRank())
891  return unusedDims;
892 
893  for (const auto &dim : llvm::enumerate(sizes))
894  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
895  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
896  unusedDims.set(dim.index());
897 
898  // Early exit for the case where the number of unused dims matches the number
899  // of ranks reduced.
900  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
901  originalType.getRank())
902  return unusedDims;
903 
904  SmallVector<int64_t> originalStrides, candidateStrides;
905  int64_t originalOffset, candidateOffset;
906  if (failed(
907  originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
908  failed(
909  reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
910  return failure();
911 
912  // For memrefs, a dimension is truly dropped if its corresponding stride is
913  // also dropped. This is particularly important when more than one of the dims
914  // is 1. Track the number of occurrences of the strides in the original type
915  // and the candidate type. For each unused dim that stride should not be
916  // present in the candidate type. Note that there could be multiple dimensions
917  // that have the same size. We don't need to exactly figure out which dim
918  // corresponds to which stride, we just need to verify that the number of
919  // repetitions of a stride in the original + number of unused dims with that
920  // stride == number of repetitions of a stride in the candidate.
921  std::map<int64_t, unsigned> currUnaccountedStrides =
922  getNumOccurences(originalStrides);
923  std::map<int64_t, unsigned> candidateStridesNumOccurences =
924  getNumOccurences(candidateStrides);
925  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
926  if (!unusedDims.test(dim))
927  continue;
928  int64_t originalStride = originalStrides[dim];
929  if (currUnaccountedStrides[originalStride] >
930  candidateStridesNumOccurences[originalStride]) {
931  // This dim can be treated as dropped.
932  currUnaccountedStrides[originalStride]--;
933  continue;
934  }
935  if (currUnaccountedStrides[originalStride] ==
936  candidateStridesNumOccurences[originalStride]) {
937  // The stride for this is not dropped. Keep as is.
938  unusedDims.reset(dim);
939  continue;
940  }
941  if (currUnaccountedStrides[originalStride] <
942  candidateStridesNumOccurences[originalStride]) {
943  // This should never happen. Can't have a stride in the reduced rank type
944  // that wasn't in the original one.
945  return failure();
946  }
947  }
948 
949  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
950  originalType.getRank())
951  return failure();
952  return unusedDims;
953 }
954 
955 llvm::SmallBitVector SubViewOp::getDroppedDims() {
956  MemRefType sourceType = getSourceType();
957  MemRefType resultType = getType();
958  FailureOr<llvm::SmallBitVector> unusedDims =
959  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
960  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
961  return *unusedDims;
962 }
963 
964 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
965  // All forms of folding require a known index.
966  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
967  if (!index)
968  return {};
969 
970  // Folding for unranked types (UnrankedMemRefType) is not supported.
971  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
972  if (!memrefType)
973  return {};
974 
975  // Out of bound indices produce undefined behavior but are still valid IR.
976  // Don't choke on them.
977  int64_t indexVal = index.getInt();
978  if (indexVal < 0 || indexVal >= memrefType.getRank())
979  return {};
980 
981  // Fold if the shape extent along the given index is known.
982  if (!memrefType.isDynamicDim(index.getInt())) {
983  Builder builder(getContext());
984  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
985  }
986 
987  // The size at the given index is now known to be a dynamic size.
988  unsigned unsignedIndex = index.getValue().getZExtValue();
989 
990  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
991  Operation *definingOp = getSource().getDefiningOp();
992 
993  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
994  return *(alloc.getDynamicSizes().begin() +
995  memrefType.getDynamicDimIndex(unsignedIndex));
996 
997  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
998  return *(alloca.getDynamicSizes().begin() +
999  memrefType.getDynamicDimIndex(unsignedIndex));
1000 
1001  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1002  return *(view.getDynamicSizes().begin() +
1003  memrefType.getDynamicDimIndex(unsignedIndex));
1004 
1005  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1006  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1007  unsigned resultIndex = 0;
1008  unsigned sourceRank = subview.getSourceType().getRank();
1009  unsigned sourceIndex = 0;
1010  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1011  if (unusedDims.test(i))
1012  continue;
1013  if (resultIndex == unsignedIndex) {
1014  sourceIndex = i;
1015  break;
1016  }
1017  resultIndex++;
1018  }
1019  assert(subview.isDynamicSize(sourceIndex) &&
1020  "expected dynamic subview size");
1021  return subview.getDynamicSize(sourceIndex);
1022  }
1023 
1024  if (auto sizeInterface =
1025  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1026  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1027  "Expected dynamic subview size");
1028  return sizeInterface.getDynamicSize(unsignedIndex);
1029  }
1030 
1031  // dim(memrefcast) -> dim
1032  if (succeeded(foldMemRefCast(*this)))
1033  return getResult();
1034 
1035  return {};
1036 }
1037 
1038 namespace {
1039 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1040 /// operand.
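///
/// For example (illustrative IR):
///
/// ```mlir
/// %r = memref.reshape %src(%shape)
///     : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
/// %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// may fold to
///
/// ```mlir
/// %d = memref.load %shape[%c1] : memref<2xindex>
/// ```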
1041 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1042  using OpRewritePattern<DimOp>::OpRewritePattern;
1043 
1044  LogicalResult matchAndRewrite(DimOp dim,
1045  PatternRewriter &rewriter) const override {
1046  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1047 
1048  if (!reshape)
1049  return rewriter.notifyMatchFailure(
1050  dim, "Dim op is not defined by a reshape op.");
1051 
1052  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1053  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1054  // cheaply check that either of the following conditions hold:
1055  // 1. dim.getIndex() is defined in the same block as reshape but before
1056  // reshape.
1057  // 2. dim.getIndex() is defined in a parent block of
1058  // reshape.
1059 
1060  // Check condition 1
1061  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1062  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1063  if (reshape->isBeforeInBlock(definingOp)) {
1064  return rewriter.notifyMatchFailure(
1065  dim,
1066  "dim.getIndex is not defined before reshape in the same block.");
1067  }
1068  } // else dim.getIndex is a block argument to reshape->getBlock and
1069  // dominates reshape
1070  } // Check condition 2
1071  else if (dim->getBlock() != reshape->getBlock() &&
1072  !dim.getIndex().getParentRegion()->isProperAncestor(
1073  reshape->getParentRegion())) {
1074  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1075  // already know dim.getIndex() dominates reshape without calling
1076  // `isProperAncestor`
1077  return rewriter.notifyMatchFailure(
1078  dim, "dim.getIndex does not dominate reshape.");
1079  }
1080 
1081  // Place the load directly after the reshape to ensure that the shape memref
1082  // was not mutated.
1083  rewriter.setInsertionPointAfter(reshape);
1084  Location loc = dim.getLoc();
1085  Value load =
1086  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1087  if (load.getType() != dim.getType())
1088  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1089  rewriter.replaceOp(dim, load);
1090  return success();
1091  }
1092 };
1093 
1094 } // namespace
1095 
1096 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1097  MLIRContext *context) {
1098  results.add<DimOfMemRefReshape>(context);
1099 }
1100 
1101 // ---------------------------------------------------------------------------
1102 // DmaStartOp
1103 // ---------------------------------------------------------------------------
1104 
1105 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1106  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1107  ValueRange destIndices, Value numElements,
1108  Value tagMemRef, ValueRange tagIndices, Value stride,
1109  Value elementsPerStride) {
1110  result.addOperands(srcMemRef);
1111  result.addOperands(srcIndices);
1112  result.addOperands(destMemRef);
1113  result.addOperands(destIndices);
1114  result.addOperands({numElements, tagMemRef});
1115  result.addOperands(tagIndices);
1116  if (stride)
1117  result.addOperands({stride, elementsPerStride});
1118 }
1119 
1120 void DmaStartOp::print(OpAsmPrinter &p) {
1121  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1122  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1123  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1124  if (isStrided())
1125  p << ", " << getStride() << ", " << getNumElementsPerStride();
1126 
1127  p.printOptionalAttrDict((*this)->getAttrs());
1128  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1129  << ", " << getTagMemRef().getType();
1130 }
1131 
1132 // Parse DmaStartOp.
1133 // Ex:
1134 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1135 // %tag[%index], %stride, %num_elt_per_stride :
1136 // : memref<3076 x f32, 0>,
1137 // memref<1024 x f32, 2>,
1138 // memref<1 x i32>
1139 //
1140 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1141  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1142  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1143  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1144  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1145  OpAsmParser::UnresolvedOperand numElementsInfo;
1146  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1147  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1148  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1149 
1150  SmallVector<Type, 3> types;
1151  auto indexType = parser.getBuilder().getIndexType();
1152 
1153  // Parse and resolve the following list of operands:
1154  // *) source memref followed by its indices (in square brackets).
1155  // *) destination memref followed by its indices (in square brackets).
1156  // *) dma size (number of elements).
1157  if (parser.parseOperand(srcMemRefInfo) ||
1158  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1159  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1160  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1161  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1162  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1163  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1164  return failure();
1165 
1166  // Parse optional stride and elements per stride.
1167  if (parser.parseTrailingOperandList(strideInfo))
1168  return failure();
1169 
1170  bool isStrided = strideInfo.size() == 2;
1171  if (!strideInfo.empty() && !isStrided) {
1172  return parser.emitError(parser.getNameLoc(),
1173  "expected two stride related operands");
1174  }
1175 
1176  if (parser.parseColonTypeList(types))
1177  return failure();
1178  if (types.size() != 3)
1179  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1180 
1181  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1182  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1183  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1184  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1185  // size should be an index.
1186  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1187  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1188  // tag indices should be index.
1189  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1190  return failure();
1191 
1192  if (isStrided) {
1193  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1194  return failure();
1195  }
1196 
1197  return success();
1198 }
1199 
1200 LogicalResult DmaStartOp::verify() {
1201  unsigned numOperands = getNumOperands();
1202 
1203  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1204  // the number of elements.
1205  if (numOperands < 4)
1206  return emitOpError("expected at least 4 operands");
1207 
1208  // Check types of operands. The order of these calls is important: the later
1209  // calls rely on some type properties to compute the operand position.
1210  // 1. Source memref.
1211  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1212  return emitOpError("expected source to be of memref type");
1213  if (numOperands < getSrcMemRefRank() + 4)
1214  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1215  << " operands";
1216  if (!getSrcIndices().empty() &&
1217  !llvm::all_of(getSrcIndices().getTypes(),
1218  [](Type t) { return t.isIndex(); }))
1219  return emitOpError("expected source indices to be of index type");
1220 
1221  // 2. Destination memref.
1222  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1223  return emitOpError("expected destination to be of memref type");
1224  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1225  if (numOperands < numExpectedOperands)
1226  return emitOpError() << "expected at least " << numExpectedOperands
1227  << " operands";
1228  if (!getDstIndices().empty() &&
1229  !llvm::all_of(getDstIndices().getTypes(),
1230  [](Type t) { return t.isIndex(); }))
1231  return emitOpError("expected destination indices to be of index type");
1232 
1233  // 3. Number of elements.
1234  if (!getNumElements().getType().isIndex())
1235  return emitOpError("expected num elements to be of index type");
1236 
1237  // 4. Tag memref.
1238  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1239  return emitOpError("expected tag to be of memref type");
1240  numExpectedOperands += getTagMemRefRank();
1241  if (numOperands < numExpectedOperands)
1242  return emitOpError() << "expected at least " << numExpectedOperands
1243  << " operands";
1244  if (!getTagIndices().empty() &&
1245  !llvm::all_of(getTagIndices().getTypes(),
1246  [](Type t) { return t.isIndex(); }))
1247  return emitOpError("expected tag indices to be of index type");
1248 
1249  // Optional stride-related operands must be either both present or both
1250  // absent.
1251  if (numOperands != numExpectedOperands &&
1252  numOperands != numExpectedOperands + 2)
1253  return emitOpError("incorrect number of operands");
1254 
1255  // 5. Strides.
1256  if (isStrided()) {
1257  if (!getStride().getType().isIndex() ||
1258  !getNumElementsPerStride().getType().isIndex())
1259  return emitOpError(
1260  "expected stride and num elements per stride to be of type index");
1261  }
1262 
1263  return success();
1264 }
1265 
1266 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1267  SmallVectorImpl<OpFoldResult> &results) {
1268  /// dma_start(memrefcast) -> dma_start
1269  return foldMemRefCast(*this);
1270 }
1271 
1272 // ---------------------------------------------------------------------------
1273 // DmaWaitOp
1274 // ---------------------------------------------------------------------------
1275 
1276 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1277  SmallVectorImpl<OpFoldResult> &results) {
1278  /// dma_wait(memrefcast) -> dma_wait
1279  return foldMemRefCast(*this);
1280 }
1281 
1282 LogicalResult DmaWaitOp::verify() {
1283  // Check that the number of tag indices matches the tagMemRef rank.
1284  unsigned numTagIndices = getTagIndices().size();
1285  unsigned tagMemRefRank = getTagMemRefRank();
1286  if (numTagIndices != tagMemRefRank)
1287  return emitOpError() << "expected tagIndices to have the same number of "
1288  "elements as the tagMemRef rank, expected "
1289  << tagMemRefRank << ", but got " << numTagIndices;
1290  return success();
1291 }
1292 
1293 //===----------------------------------------------------------------------===//
1294 // ExtractAlignedPointerAsIndexOp
1295 //===----------------------------------------------------------------------===//
1296 
1297 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1298  function_ref<void(Value, StringRef)> setNameFn) {
1299  setNameFn(getResult(), "intptr");
1300 }
1301 
1302 //===----------------------------------------------------------------------===//
1303 // ExtractStridedMetadataOp
1304 //===----------------------------------------------------------------------===//
1305 
1306 /// The number and type of the results are inferred from the
1307 /// shape of the source.
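///
/// For example (illustrative IR), a rank-2 source produces a 0-d base buffer,
/// an offset, two sizes, and two strides:
///
/// ```mlir
/// %base, %offset, %sizes:2, %strides:2 =
///     memref.extract_strided_metadata %m : memref<4x?xf32>
///     -> memref<f32>, index, index, index, index, index
/// ```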
1308 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1309  MLIRContext *context, std::optional<Location> location,
1310  ExtractStridedMetadataOp::Adaptor adaptor,
1311  SmallVectorImpl<Type> &inferredReturnTypes) {
1312  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1313  if (!sourceType)
1314  return failure();
1315 
1316  unsigned sourceRank = sourceType.getRank();
1317  IndexType indexType = IndexType::get(context);
1318  auto memrefType =
1319  MemRefType::get({}, sourceType.getElementType(),
1320  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1321  // Base.
1322  inferredReturnTypes.push_back(memrefType);
1323  // Offset.
1324  inferredReturnTypes.push_back(indexType);
1325  // Sizes and strides.
1326  for (unsigned i = 0; i < sourceRank * 2; ++i)
1327  inferredReturnTypes.push_back(indexType);
1328  return success();
1329 }
1330 
1331 void ExtractStridedMetadataOp::getAsmResultNames(
1332  function_ref<void(Value, StringRef)> setNameFn) {
1333  setNameFn(getBaseBuffer(), "base_buffer");
1334  setNameFn(getOffset(), "offset");
1335  // For multi-result to work properly with pretty names and packed syntax `x:3`
1336  // we can only give a pretty name to the first value in the pack.
1337  if (!getSizes().empty()) {
1338  setNameFn(getSizes().front(), "sizes");
1339  setNameFn(getStrides().front(), "strides");
1340  }
1341 }
1342 
1343 /// Helper function to perform the replacement of all constant uses of `values`
1344 /// by a materialized constant extracted from `maybeConstants`.
1345 /// `values` and `maybeConstants` are expected to have the same size.
1346 template <typename Container>
1347 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1348  Container values,
1349  ArrayRef<OpFoldResult> maybeConstants) {
1350  assert(values.size() == maybeConstants.size() &&
1351  " expected values and maybeConstants of the same size");
1352  bool atLeastOneReplacement = false;
1353  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1354  // Don't materialize a constant if there are no uses: this would induce
1355  // infinite loops in the driver.
1356  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1357  continue;
1358  assert(isa<Attribute>(maybeConstant) &&
1359  "The constified value should be either unchanged (i.e., == result) "
1360  "or a constant");
1361  Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1362  loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1363  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1364  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1365  // yet.
1366  op->replaceUsesOfWith(result, constantVal);
1367  atLeastOneReplacement = true;
1368  }
1369  }
1370  return atLeastOneReplacement;
1371 }
1372 
1373 LogicalResult
1374 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1375  SmallVectorImpl<OpFoldResult> &results) {
1376  OpBuilder builder(*this);
1377 
1378  bool atLeastOneReplacement = replaceConstantUsesOf(
1379  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1380  getConstifiedMixedOffset());
1381  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1382  getConstifiedMixedSizes());
1383  atLeastOneReplacement |= replaceConstantUsesOf(
1384  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1385 
1386  return success(atLeastOneReplacement);
1387 }
1388 
1389 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1390  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1391  constifyIndexValues(values, getSource().getType().getShape());
1392  return values;
1393 }
1394 
1395 SmallVector<OpFoldResult>
1396 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1397  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1398  SmallVector<int64_t> staticValues;
1399  int64_t unused;
1400  LogicalResult status =
1401  getSource().getType().getStridesAndOffset(staticValues, unused);
1402  (void)status;
1403  assert(succeeded(status) && "could not get strides from type");
1404  constifyIndexValues(values, staticValues);
1405  return values;
1406 }
1407 
1408 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1409  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1410  SmallVector<OpFoldResult> values(1, offsetOfr);
1411  SmallVector<int64_t> staticValues, unused;
1412  int64_t offset;
1413  LogicalResult status =
1414  getSource().getType().getStridesAndOffset(unused, offset);
1415  (void)status;
1416  assert(succeeded(status) && "could not get offset from type");
1417  staticValues.push_back(offset);
1418  constifyIndexValues(values, staticValues);
1419  return values[0];
1420 }
1421 
1422 //===----------------------------------------------------------------------===//
1423 // GenericAtomicRMWOp
1424 //===----------------------------------------------------------------------===//
1425 
1426 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1427  Value memref, ValueRange ivs) {
1428  OpBuilder::InsertionGuard g(builder);
1429  result.addOperands(memref);
1430  result.addOperands(ivs);
1431 
1432  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1433  Type elementType = memrefType.getElementType();
1434  result.addTypes(elementType);
1435 
1436  Region *bodyRegion = result.addRegion();
1437  builder.createBlock(bodyRegion);
1438  bodyRegion->addArgument(elementType, memref.getLoc());
1439  }
1440 }
1441 
1442 LogicalResult GenericAtomicRMWOp::verify() {
1443  auto &body = getRegion();
1444  if (body.getNumArguments() != 1)
1445  return emitOpError("expected single number of entry block arguments");
1446 
1447  if (getResult().getType() != body.getArgument(0).getType())
1448  return emitOpError("expected block argument of the same type result type");
1449 
1450  bool hasSideEffects =
1451  body.walk([&](Operation *nestedOp) {
1452  if (isMemoryEffectFree(nestedOp))
1453  return WalkResult::advance();
1454  nestedOp->emitError(
1455  "body of 'memref.generic_atomic_rmw' should contain "
1456  "only operations with no side effects");
1457  return WalkResult::interrupt();
1458  })
1459  .wasInterrupted();
1460  return hasSideEffects ? failure() : success();
1461 }
1462 
1463 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1464  OperationState &result) {
1465  OpAsmParser::UnresolvedOperand memref;
1466  Type memrefType;
1467  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1468 
1469  Type indexType = parser.getBuilder().getIndexType();
1470  if (parser.parseOperand(memref) ||
1471  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1472  parser.parseColonType(memrefType) ||
1473  parser.resolveOperand(memref, memrefType, result.operands) ||
1474  parser.resolveOperands(ivs, indexType, result.operands))
1475  return failure();
1476 
1477  Region *body = result.addRegion();
1478  if (parser.parseRegion(*body, {}) ||
1479  parser.parseOptionalAttrDict(result.attributes))
1480  return failure();
1481  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1482  return success();
1483 }
1484 
1485 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1486  p << ' ' << getMemref() << "[" << getIndices()
1487  << "] : " << getMemref().getType() << ' ';
1488  p.printRegion(getRegion());
1489  p.printOptionalAttrDict((*this)->getAttrs());
1490 }
1491 
1492 //===----------------------------------------------------------------------===//
1493 // AtomicYieldOp
1494 //===----------------------------------------------------------------------===//
1495 
1496 LogicalResult AtomicYieldOp::verify() {
1497  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1498  Type resultType = getResult().getType();
1499  if (parentType != resultType)
1500  return emitOpError() << "types mismatch between yield op: " << resultType
1501  << " and its parent: " << parentType;
1502  return success();
1503 }
1504 
1505 //===----------------------------------------------------------------------===//
1506 // GlobalOp
1507 //===----------------------------------------------------------------------===//
1508 
1509 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1510  TypeAttr type,
1511  Attribute initialValue) {
1512  p << type;
1513  if (!op.isExternal()) {
1514  p << " = ";
1515  if (op.isUninitialized())
1516  p << "uninitialized";
1517  else
1518  p.printAttributeWithoutType(initialValue);
1519  }
1520 }
1521 
1522 static ParseResult
1523 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1524  Attribute &initialValue) {
1525  Type type;
1526  if (parser.parseType(type))
1527  return failure();
1528 
1529  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1530  if (!memrefType || !memrefType.hasStaticShape())
1531  return parser.emitError(parser.getNameLoc())
1532  << "type should be static shaped memref, but got " << type;
1533  typeAttr = TypeAttr::get(type);
1534 
1535  if (parser.parseOptionalEqual())
1536  return success();
1537 
1538  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1539  initialValue = UnitAttr::get(parser.getContext());
1540  return success();
1541  }
1542 
1543  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1544  if (parser.parseAttribute(initialValue, tensorType))
1545  return failure();
1546  if (!llvm::isa<ElementsAttr>(initialValue))
1547  return parser.emitError(parser.getNameLoc())
1548  << "initial value should be a unit or elements attribute";
1549  return success();
1550 }
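// For reference, sketches of the textual forms this printer/parser pair is
// meant to handle (symbol names, types, and values are illustrative):
// ```
// memref.global "private" @x : memref<2xf32> = dense<[0.0, 2.0]>
// memref.global @y : memref<3xf16> = uninitialized
// memref.global "private" constant @c : memref<2xi32> = dense<[1, 4]>
// memref.global "private" @ext : memref<8xf32>   // external: no initializer
// ```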
1551 
1552 LogicalResult GlobalOp::verify() {
1553  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1554  if (!memrefType || !memrefType.hasStaticShape())
1555  return emitOpError("type should be static shaped memref, but got ")
1556  << getType();
1557 
1558  // Verify that the initial value, if present, is either a unit attribute or
1559  // an elements attribute.
1560  if (getInitialValue().has_value()) {
1561  Attribute initValue = getInitialValue().value();
1562  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1563  return emitOpError("initial value should be a unit or elements "
1564  "attribute, but got ")
1565  << initValue;
1566 
1567  // Check that the type of the initial value is compatible with the type of
1568  // the global variable.
1569  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1570  Type initType = elementsAttr.getType();
1571  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1572  if (initType != tensorType)
1573  return emitOpError("initial value expected to be of type ")
1574  << tensorType << ", but was of type " << initType;
1575  }
1576  }
1577 
1578  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1579  uint64_t alignment = *alignAttr;
1580 
1581  if (!llvm::isPowerOf2_64(alignment))
1582  return emitError() << "alignment attribute value " << alignment
1583  << " is not a power of 2";
1584  }
1585 
1586  // TODO: verify visibility for declarations.
1587  return success();
1588 }
1589 
1590 ElementsAttr GlobalOp::getConstantInitValue() {
1591  auto initVal = getInitialValue();
1592  if (getConstant() && initVal.has_value())
1593  return llvm::cast<ElementsAttr>(initVal.value());
1594  return {};
1595 }
1596 
1597 //===----------------------------------------------------------------------===//
1598 // GetGlobalOp
1599 //===----------------------------------------------------------------------===//
1600 
1601 LogicalResult
1602 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1603  // Verify that the result type is the same as the type of the referenced
1604  // memref.global op.
1605  auto global =
1606  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1607  if (!global)
1608  return emitOpError("'")
1609  << getName() << "' does not reference a valid global memref";
1610 
1611  Type resultType = getResult().getType();
1612  if (global.getType() != resultType)
1613  return emitOpError("result type ")
1614  << resultType << " does not match type " << global.getType()
1615  << " of the global memref @" << getName();
1616  return success();
1617 }
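// Illustrative sketch of a reference that satisfies this check (the @x symbol
// and its type are hypothetical; the result type must match the global's type
// exactly):
// ```
// memref.global "private" @x : memref<2xf32> = dense<[0.0, 2.0]>
// %0 = memref.get_global @x : memref<2xf32>
// ```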
1618 
1619 //===----------------------------------------------------------------------===//
1620 // LoadOp
1621 //===----------------------------------------------------------------------===//
1622 
1623 LogicalResult LoadOp::verify() {
1624  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1625  return emitOpError("incorrect number of indices for load, expected ")
1626  << getMemRefType().getRank() << " but got " << getIndices().size();
1627  }
1628  return success();
1629 }
1630 
1631 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1632  /// load(memrefcast) -> load
1633  if (succeeded(foldMemRefCast(*this)))
1634  return getResult();
1635  return OpFoldResult();
1636 }
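// Illustrative sketch of the load(memrefcast) -> load fold (names and types
// are hypothetical):
// ```
// %0 = memref.cast %A : memref<16xf32> to memref<?xf32>
// %1 = memref.load %0[%i] : memref<?xf32>
// // folds to:
// %1 = memref.load %A[%i] : memref<16xf32>
// ```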
1637 
1638 //===----------------------------------------------------------------------===//
1639 // MemorySpaceCastOp
1640 //===----------------------------------------------------------------------===//
1641 
1642 void MemorySpaceCastOp::getAsmResultNames(
1643  function_ref<void(Value, StringRef)> setNameFn) {
1644  setNameFn(getResult(), "memspacecast");
1645 }
1646 
1647 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1648  if (inputs.size() != 1 || outputs.size() != 1)
1649  return false;
1650  Type a = inputs.front(), b = outputs.front();
1651  auto aT = llvm::dyn_cast<MemRefType>(a);
1652  auto bT = llvm::dyn_cast<MemRefType>(b);
1653 
1654  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1655  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1656 
1657  if (aT && bT) {
1658  if (aT.getElementType() != bT.getElementType())
1659  return false;
1660  if (aT.getLayout() != bT.getLayout())
1661  return false;
1662  if (aT.getShape() != bT.getShape())
1663  return false;
1664  return true;
1665  }
1666  if (uaT && ubT) {
1667  return uaT.getElementType() == ubT.getElementType();
1668  }
1669  return false;
1670 }
1671 
1672 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1673  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1674  // t2)
1675  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1676  getSourceMutable().assign(parentCast.getSource());
1677  return getResult();
1678  }
1679  return Value{};
1680 }
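// Illustrative sketch of the composition fold above (the memory spaces 1 and 2
// are hypothetical):
// ```
// %1 = memref.memory_space_cast %0 : memref<16xf32, 1> to memref<16xf32, 2>
// %2 = memref.memory_space_cast %1 : memref<16xf32, 2> to memref<16xf32>
// // folds to:
// %2 = memref.memory_space_cast %0 : memref<16xf32, 1> to memref<16xf32>
// ```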
1681 
1682 //===----------------------------------------------------------------------===//
1683 // PrefetchOp
1684 //===----------------------------------------------------------------------===//
1685 
1686 void PrefetchOp::print(OpAsmPrinter &p) {
1687  p << " " << getMemref() << '[';
1688  p.printOperands(getIndices());
1689  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1690  p << ", locality<" << getLocalityHint();
1691  p << ">, " << (getIsDataCache() ? "data" : "instr");
1692  p.printOptionalAttrDict(
1693  (*this)->getAttrs(),
1694  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1695  p << " : " << getMemRefType();
1696 }
1697 
1698 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1699  OpAsmParser::UnresolvedOperand memrefInfo;
1700  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1701  IntegerAttr localityHint;
1702  MemRefType type;
1703  StringRef readOrWrite, cacheType;
1704 
1705  auto indexTy = parser.getBuilder().getIndexType();
1706  auto i32Type = parser.getBuilder().getIntegerType(32);
1707  if (parser.parseOperand(memrefInfo) ||
1708  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1709  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1710  parser.parseComma() || parser.parseKeyword("locality") ||
1711  parser.parseLess() ||
1712  parser.parseAttribute(localityHint, i32Type, "localityHint",
1713  result.attributes) ||
1714  parser.parseGreater() || parser.parseComma() ||
1715  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1716  parser.resolveOperand(memrefInfo, type, result.operands) ||
1717  parser.resolveOperands(indexInfo, indexTy, result.operands))
1718  return failure();
1719 
1720  if (readOrWrite != "read" && readOrWrite != "write")
1721  return parser.emitError(parser.getNameLoc(),
1722  "rw specifier has to be 'read' or 'write'");
1723  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1724  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1725 
1726  if (cacheType != "data" && cacheType != "instr")
1727  return parser.emitError(parser.getNameLoc(),
1728  "cache type has to be 'data' or 'instr'");
1729 
1730  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1731  parser.getBuilder().getBoolAttr(cacheType == "data"));
1732 
1733  return success();
1734 }
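// For reference, a sketch of the form accepted by the parser above (operands
// and the memref type are illustrative):
// ```
// memref.prefetch %A[%i, %j], read, locality<3>, data : memref<400x400xf32>
// ```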
1735 
1736 LogicalResult PrefetchOp::verify() {
1737  if (getNumOperands() != 1 + getMemRefType().getRank())
1738  return emitOpError("too few indices");
1739 
1740  return success();
1741 }
1742 
1743 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1744  SmallVectorImpl<OpFoldResult> &results) {
1745  // prefetch(memrefcast) -> prefetch
1746  return foldMemRefCast(*this);
1747 }
1748 
1749 //===----------------------------------------------------------------------===//
1750 // RankOp
1751 //===----------------------------------------------------------------------===//
1752 
1753 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1754  // Constant fold rank when the rank of the operand is known.
1755  auto type = getOperand().getType();
1756  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1757  if (shapedType && shapedType.hasRank())
1758  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1759  return IntegerAttr();
1760 }
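// Illustrative sketch of the fold (the memref type is hypothetical): the
// result folds to the constant index 2 because the rank is statically known,
// even though one dimension is dynamic.
// ```
// %rank = memref.rank %A : memref<4x?xf32>
// ```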
1761 
1762 //===----------------------------------------------------------------------===//
1763 // ReinterpretCastOp
1764 //===----------------------------------------------------------------------===//
1765 
1766 void ReinterpretCastOp::getAsmResultNames(
1767  function_ref<void(Value, StringRef)> setNameFn) {
1768  setNameFn(getResult(), "reinterpret_cast");
1769 }
1770 
1771 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1772 /// `staticSizes` and `staticStrides` are automatically filled with
1773 /// source-memref-rank sentinel values that encode dynamic entries.
1774 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1775  MemRefType resultType, Value source,
1776  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1777  ArrayRef<OpFoldResult> strides,
1778  ArrayRef<NamedAttribute> attrs) {
1779  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1780  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1781  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1782  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1783  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1784  result.addAttributes(attrs);
1785  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1786  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1787  b.getDenseI64ArrayAttr(staticSizes),
1788  b.getDenseI64ArrayAttr(staticStrides));
1789 }
1790 
1791 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1792  Value source, OpFoldResult offset,
1793  ArrayRef<OpFoldResult> sizes,
1794  ArrayRef<OpFoldResult> strides,
1795  ArrayRef<NamedAttribute> attrs) {
1796  auto sourceType = cast<BaseMemRefType>(source.getType());
1797  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1798  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1799  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1800  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1801  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1802  auto stridedLayout = StridedLayoutAttr::get(
1803  b.getContext(), staticOffsets.front(), staticStrides);
1804  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1805  stridedLayout, sourceType.getMemorySpace());
1806  build(b, result, resultType, source, offset, sizes, strides, attrs);
1807 }
1808 
1809 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1810  MemRefType resultType, Value source,
1811  int64_t offset, ArrayRef<int64_t> sizes,
1812  ArrayRef<int64_t> strides,
1813  ArrayRef<NamedAttribute> attrs) {
1814  SmallVector<OpFoldResult> sizeValues =
1815  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1816  return b.getI64IntegerAttr(v);
1817  }));
1818  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1819  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1820  return b.getI64IntegerAttr(v);
1821  }));
1822  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1823  strideValues, attrs);
1824 }
1825 
1826 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1827  MemRefType resultType, Value source, Value offset,
1828  ValueRange sizes, ValueRange strides,
1829  ArrayRef<NamedAttribute> attrs) {
1830  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1831  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1832  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1833  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1834  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1835 }
1836 
1837 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1838 // completed automatically, like we have for subview and extract_slice.
1839 LogicalResult ReinterpretCastOp::verify() {
1840  // The source and result memrefs should be in the same memory space.
1841  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1842  auto resultType = llvm::cast<MemRefType>(getType());
1843  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1844  return emitError("different memory spaces specified for source type ")
1845  << srcType << " and result memref type " << resultType;
1846  if (srcType.getElementType() != resultType.getElementType())
1847  return emitError("different element types specified for source type ")
1848  << srcType << " and result memref type " << resultType;
1849 
1850  // Match sizes in result memref type and in static_sizes attribute.
1851  for (auto [idx, resultSize, expectedSize] :
1852  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1853  if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
1854  return emitError("expected result type with size = ")
1855  << (ShapedType::isDynamic(expectedSize)
1856  ? std::string("dynamic")
1857  : std::to_string(expectedSize))
1858  << " instead of " << resultSize << " in dim = " << idx;
1859  }
1860 
1861  // Match offset and strides in the static_offsets and static_strides attributes. If
1862  // result memref type has no affine map specified, this will assume an
1863  // identity layout.
1864  int64_t resultOffset;
1865  SmallVector<int64_t, 4> resultStrides;
1866  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1867  return emitError("expected result type to have strided layout but found ")
1868  << resultType;
1869 
1870  // Match offset in result memref type and in static_offsets attribute.
1871  int64_t expectedOffset = getStaticOffsets().front();
1872  if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
1873  return emitError("expected result type with offset = ")
1874  << (ShapedType::isDynamic(expectedOffset)
1875  ? std::string("dynamic")
1876  : std::to_string(expectedOffset))
1877  << " instead of " << resultOffset;
1878 
1879  // Match strides in result memref type and in static_strides attribute.
1880  for (auto [idx, resultStride, expectedStride] :
1881  llvm::enumerate(resultStrides, getStaticStrides())) {
1882  if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
1883  return emitError("expected result type with stride = ")
1884  << (ShapedType::isDynamic(expectedStride)
1885  ? std::string("dynamic")
1886  : std::to_string(expectedStride))
1887  << " instead of " << resultStride << " in dim = " << idx;
1888  }
1889 
1890  return success();
1891 }
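// A sketch of an op that passes the checks above (names and types are
// illustrative): the static offset, sizes, and strides agree with the strided
// layout of the result type.
// ```
// %dst = memref.reinterpret_cast %src to
//     offset: [0], sizes: [%sz, 4], strides: [4, 1]
//     : memref<?x?xf32> to memref<?x4xf32, strided<[4, 1]>>
// ```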
1892 
1893 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1894  Value src = getSource();
1895  auto getPrevSrc = [&]() -> Value {
1896  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1897  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1898  return prev.getSource();
1899 
1900  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1901  if (auto prev = src.getDefiningOp<CastOp>())
1902  return prev.getSource();
1903 
1904  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1905  // are 0.
1906  if (auto prev = src.getDefiningOp<SubViewOp>())
1907  if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
1908  return prev.getSource();
1909 
1910  return nullptr;
1911  };
1912 
1913  if (auto prevSrc = getPrevSrc()) {
1914  getSourceMutable().assign(prevSrc);
1915  return getResult();
1916  }
1917 
1918  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1919  if (!ShapedType::isDynamicShape(getType().getShape()) &&
1920  src.getType() == getType() && getStaticOffsets().front() == 0) {
1921  return src;
1922  }
1923 
1924  return nullptr;
1925 }
1926 
1927 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1928  SmallVector<OpFoldResult> values = getMixedSizes();
1929  constifyIndexValues(values, getType().getShape());
1930  return values;
1931 }
1932 
1933 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1934  SmallVector<OpFoldResult> values = getMixedStrides();
1935  SmallVector<int64_t> staticValues;
1936  int64_t unused;
1937  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
1938  (void)status;
1939  assert(succeeded(status) && "could not get strides from type");
1940  constifyIndexValues(values, staticValues);
1941  return values;
1942 }
1943 
1944 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1945  SmallVector<OpFoldResult> values = getMixedOffsets();
1946  assert(values.size() == 1 &&
1947  "reinterpret_cast must have one and only one offset");
1948  SmallVector<int64_t> staticValues, unused;
1949  int64_t offset;
1950  LogicalResult status = getType().getStridesAndOffset(unused, offset);
1951  (void)status;
1952  assert(succeeded(status) && "could not get offset from type");
1953  staticValues.push_back(offset);
1954  constifyIndexValues(values, staticValues);
1955  return values[0];
1956 }
1957 
1958 namespace {
1959 /// Replace the sequence:
1960 /// ```
1961 /// base, offset, sizes, strides = extract_strided_metadata src
1962 /// dst = reinterpret_cast base to offset, sizes, strides
1963 /// ```
1964 /// With
1965 ///
1966 /// ```
1967 /// dst = memref.cast src
1968 /// ```
1969 ///
1970 /// Note: The cast operation is only inserted when the type of dst and src
1971 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
1972 ///
1973 /// This pattern also matches when the offset, sizes, and strides don't come
1974 /// directly from the `extract_strided_metadata`'s results but it can be
1975 /// statically proven that they would hold the same values.
1976 ///
1977 /// For instance, the following sequence would be replaced:
1978 /// ```
1979 /// base, offset, sizes, strides =
1980 /// extract_strided_metadata memref : memref<3x4xty>
1981 /// dst = reinterpret_cast base to 0, [3, 4], strides
1982 /// ```
1983 /// Because we know (thanks to the type of the input memref) that variable
1984 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
1985 ///
1986 /// Similarly, the following sequence would be replaced:
1987 /// ```
1988 /// c0 = arith.constant 0
1989 /// c4 = arith.constant 4
1990 /// base, offset, sizes, strides =
1991 /// extract_strided_metadata memref : memref<3x4xty>
1992 /// dst = reinterpret_cast base to c0, [3, c4], strides
1993 /// ```
1994 /// Because we know that `offset` and `c0` will hold 0
1995 /// and `c4` will hold 4.
1996 ///
1997 /// If the pattern above does not match, the input of the
1998 /// extract_strided_metadata is always folded into the input of the
1999 /// reinterpret_cast operator. This allows for dead code elimination to get rid
2000 /// of the extract_strided_metadata in some cases.
2001 struct ReinterpretCastOpExtractStridedMetadataFolder
2002  : public OpRewritePattern<ReinterpretCastOp> {
2003 public:
2004  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2005 
2006  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2007  PatternRewriter &rewriter) const override {
2008  auto extractStridedMetadata =
2009  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2010  if (!extractStridedMetadata)
2011  return failure();
2012 
2013  // Check if the reinterpret cast reconstructs a memref with the exact same
2014  // properties as the extract strided metadata.
2015  auto isReinterpretCastNoop = [&]() -> bool {
2016  // First, check that the strides are the same.
2017  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
2018  op.getConstifiedMixedStrides()))
2019  return false;
2020 
2021  // Second, check the sizes.
2022  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
2023  op.getConstifiedMixedSizes()))
2024  return false;
2025 
2026  // Finally, check the offset.
2027  assert(op.getMixedOffsets().size() == 1 &&
2028  "reinterpret_cast with more than one offset should have been "
2029  "rejected by the verifier");
2030  return extractStridedMetadata.getConstifiedMixedOffset() ==
2031  op.getConstifiedMixedOffset();
2032  };
2033 
2034  if (!isReinterpretCastNoop()) {
2035  // If the extract_strided_metadata / reinterpret_cast pair can't be
2036  // completely folded, then we can still fold the input of the
2037  // extract_strided_metadata into the input of the reinterpret_cast.
2038  // In some cases (e.g., static dimensions), the extract_strided_metadata
2039  // is then eliminated by dead code elimination.
2040  //
2041  // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2042  //
2043  // We can always fold the input of a extract_strided_metadata operator
2044  // to the input of a reinterpret_cast operator, because they point to
2045  // the same memory. Note that the reinterpret_cast does not use the
2046  // layout of its input memref, only its base memory pointer which is
2047  // the same as the base pointer returned by the extract_strided_metadata
2048  // operator and the base pointer of the extract_strided_metadata memref
2049  // input.
2050  rewriter.modifyOpInPlace(op, [&]() {
2051  op.getSourceMutable().assign(extractStridedMetadata.getSource());
2052  });
2053  return success();
2054  }
2055 
2056  // At this point, we know that the back and forth between extract strided
2057  // metadata and reinterpret cast is a noop. However, the final type of the
2058  // reinterpret cast may not be exactly the same as the original memref.
2059  // E.g., it could be changing a dimension from static to dynamic. Check that
2060  // here and add a cast if necessary.
2061  Type srcTy = extractStridedMetadata.getSource().getType();
2062  if (srcTy == op.getResult().getType())
2063  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2064  else
2065  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2066  extractStridedMetadata.getSource());
2067 
2068  return success();
2069  }
2070 };
2071 } // namespace
2072 
2073 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2074  MLIRContext *context) {
2075  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2076 }
2077 
2078 //===----------------------------------------------------------------------===//
2079 // Reassociative reshape ops
2080 //===----------------------------------------------------------------------===//
2081 
2082 void CollapseShapeOp::getAsmResultNames(
2083  function_ref<void(Value, StringRef)> setNameFn) {
2084  setNameFn(getResult(), "collapse_shape");
2085 }
2086 
2087 void ExpandShapeOp::getAsmResultNames(
2088  function_ref<void(Value, StringRef)> setNameFn) {
2089  setNameFn(getResult(), "expand_shape");
2090 }
2091 
2092 LogicalResult ExpandShapeOp::reifyResultShapes(
2093  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2094  reifiedResultShapes = {
2095  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2096  return success();
2097 }
2098 
2099 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2100 /// result and operand. Layout maps are verified separately.
2101 ///
2102 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2103 /// allowed in a reassociation group.
2104 static LogicalResult
2105 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2106  ArrayRef<int64_t> expandedShape,
2107  ArrayRef<ReassociationIndices> reassociation,
2108  bool allowMultipleDynamicDimsPerGroup) {
2109  // There must be one reassociation group per collapsed dimension.
2110  if (collapsedShape.size() != reassociation.size())
2111  return op->emitOpError("invalid number of reassociation groups: found ")
2112  << reassociation.size() << ", expected " << collapsedShape.size();
2113 
2114  // The next expected expanded dimension index (while iterating over
2115  // reassociation indices).
2116  int64_t nextDim = 0;
2117  for (const auto &it : llvm::enumerate(reassociation)) {
2118  ReassociationIndices group = it.value();
2119  int64_t collapsedDim = it.index();
2120 
2121  bool foundDynamic = false;
2122  for (int64_t expandedDim : group) {
2123  if (expandedDim != nextDim++)
2124  return op->emitOpError("reassociation indices must be contiguous");
2125 
2126  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2127  return op->emitOpError("reassociation index ")
2128  << expandedDim << " is out of bounds";
2129 
2130  // Check if there are multiple dynamic dims in a reassociation group.
2131  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2132  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2133  return op->emitOpError(
2134  "at most one dimension in a reassociation group may be dynamic");
2135  foundDynamic = true;
2136  }
2137  }
2138 
2139  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2140  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2141  return op->emitOpError("collapsed dim (")
2142  << collapsedDim
2143  << ") must be dynamic if and only if reassociation group is "
2144  "dynamic";
2145 
2146  // If all dims in the reassociation group are static, the size of the
2147  // collapsed dim can be verified.
2148  if (!foundDynamic) {
2149  int64_t groupSize = 1;
2150  for (int64_t expandedDim : group)
2151  groupSize *= expandedShape[expandedDim];
2152  if (groupSize != collapsedShape[collapsedDim])
2153  return op->emitOpError("collapsed dim size (")
2154  << collapsedShape[collapsedDim]
2155  << ") must equal reassociation group size (" << groupSize << ")";
2156  }
2157  }
2158 
2159  if (collapsedShape.empty()) {
2160  // Rank 0: All expanded dimensions must be 1.
2161  for (int64_t d : expandedShape)
2162  if (d != 1)
2163  return op->emitOpError(
2164  "rank 0 memrefs can only be extended/collapsed with/from ones");
2165  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2166  // Rank >= 1: Number of dimensions among all reassociation groups must match
2167  // the result memref rank.
2168  return op->emitOpError("expanded rank (")
2169  << expandedShape.size()
2170  << ") inconsistent with number of reassociation indices (" << nextDim
2171  << ")";
2172  }
2173 
2174  return success();
2175 }
2176 
2177 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2178  return getSymbolLessAffineMaps(getReassociationExprs());
2179 }
2180 
2181 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2182  return convertReassociationIndicesToExprs(getContext(),
2183  getReassociationIndices());
2184 }
2185 
2186 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2187  return getSymbolLessAffineMaps(getReassociationExprs());
2188 }
2189 
2190 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2191  return convertReassociationIndicesToExprs(getContext(),
2192  getReassociationIndices());
2193 }
2194 
2195 /// Compute the layout map after expanding a given source MemRef type with the
2196 /// specified reassociation indices.
2197 static FailureOr<StridedLayoutAttr>
2198 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2199  ArrayRef<ReassociationIndices> reassociation) {
2200  int64_t srcOffset;
2201  SmallVector<int64_t> srcStrides;
2202  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2203  return failure();
2204  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2205 
2206  // 1-1 mapping between srcStrides and reassociation packs.
2207  // Each srcStride starts with the given value and gets expanded according to
2208  // the proper entries in resultShape.
2209  // Example:
2210  // srcStrides = [10000, 1 , 100 ],
2211  // reassociations = [ [0], [1], [2, 3, 4]],
2212  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2213  // -> For the purpose of stride calculation, the useful sizes are:
2214  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2215  // resultStrides = [10000, 1, 600, 200, 100]
2216  // Note that a stride does not get expanded along the first entry of each
2217  // shape pack.
2218  SmallVector<int64_t> reverseResultStrides;
2219  reverseResultStrides.reserve(resultShape.size());
2220  unsigned shapeIndex = resultShape.size() - 1;
2221  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2222  ReassociationIndices reassoc = std::get<0>(it);
2223  int64_t currentStrideToExpand = std::get<1>(it);
2224  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2225  reverseResultStrides.push_back(currentStrideToExpand);
2226  currentStrideToExpand =
2227  (SaturatedInteger::wrap(currentStrideToExpand) *
2228  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2229  .asInteger();
2230  }
2231  }
2232  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2233  resultStrides.resize(resultShape.size(), 1);
2234  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2235 }
2236 
2237 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2238  MemRefType srcType, ArrayRef<int64_t> resultShape,
2239  ArrayRef<ReassociationIndices> reassociation) {
2240  if (srcType.getLayout().isIdentity()) {
2241  // If the source is contiguous (i.e., no layout map specified), so is the
2242  // result.
2243  MemRefLayoutAttrInterface layout;
2244  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2245  srcType.getMemorySpace());
2246  }
2247 
2248  // Source may not be contiguous. Compute the layout map.
2249  FailureOr<StridedLayoutAttr> computedLayout =
2250  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2251  if (failed(computedLayout))
2252  return failure();
2253  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2254  srcType.getMemorySpace());
2255 }
2256 
2257 FailureOr<SmallVector<OpFoldResult>>
2258 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2259  MemRefType expandedType,
2260  ArrayRef<ReassociationIndices> reassociation,
2261  ArrayRef<OpFoldResult> inputShape) {
2262  std::optional<SmallVector<OpFoldResult>> outputShape =
2263  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2264  inputShape);
2265  if (!outputShape)
2266  return failure();
2267  return *outputShape;
2268 }
2269 
2270 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2271  Type resultType, Value src,
2272  ArrayRef<ReassociationIndices> reassociation,
2273  ArrayRef<OpFoldResult> outputShape) {
2274  auto [staticOutputShape, dynamicOutputShape] =
2275  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2276  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2277  getReassociationIndicesAttribute(builder, reassociation),
2278  dynamicOutputShape, staticOutputShape);
2279 }
2280 
2281 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2282  Type resultType, Value src,
2283  ArrayRef<ReassociationIndices> reassociation) {
2284  SmallVector<OpFoldResult> inputShape =
2285  getMixedSizes(builder, result.location, src);
2286  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2287  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2288  builder, result.location, memrefResultTy, reassociation, inputShape);
2289  // Failure of this assertion usually indicates presence of multiple
2290  // dynamic dimensions in the same reassociation group.
2291  assert(succeeded(outputShape) && "unable to infer output shape");
2292  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2293 }
2294 
2295 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2296  ArrayRef<int64_t> resultShape, Value src,
2297  ArrayRef<ReassociationIndices> reassociation) {
2298  // Only ranked memref source values are supported.
2299  auto srcType = llvm::cast<MemRefType>(src.getType());
2300  FailureOr<MemRefType> resultType =
2301  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2302  // Failure of this assertion usually indicates a problem with the source
2303  // type, e.g., could not get strides/offset.
2304  assert(succeeded(resultType) && "could not compute layout");
2305  build(builder, result, *resultType, src, reassociation);
2306 }
2307 
2308 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2309  ArrayRef<int64_t> resultShape, Value src,
2310  ArrayRef<ReassociationIndices> reassociation,
2311  ArrayRef<OpFoldResult> outputShape) {
2312  // Only ranked memref source values are supported.
2313  auto srcType = llvm::cast<MemRefType>(src.getType());
2314  FailureOr<MemRefType> resultType =
2315  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2316  // Failure of this assertion usually indicates a problem with the source
2317  // type, e.g., could not get strides/offset.
2318  assert(succeeded(resultType) && "could not compute layout");
2319  build(builder, result, *resultType, src, reassociation, outputShape);
2320 }
2321 
2322 LogicalResult ExpandShapeOp::verify() {
2323  MemRefType srcType = getSrcType();
2324  MemRefType resultType = getResultType();
2325 
2326  if (srcType.getRank() > resultType.getRank()) {
2327  auto r0 = srcType.getRank();
2328  auto r1 = resultType.getRank();
2329  return emitOpError("has source rank ")
2330  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2331  << r0 << " > " << r1 << ").";
2332  }
2333 
2334  // Verify result shape.
2335  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2336  resultType.getShape(),
2337  getReassociationIndices(),
2338  /*allowMultipleDynamicDimsPerGroup=*/true)))
2339  return failure();
2340 
2341  // Compute expected result type (including layout map).
2342  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2343  srcType, resultType.getShape(), getReassociationIndices());
2344  if (failed(expectedResultType))
2345  return emitOpError("invalid source layout map");
2346 
2347  // Check actual result type.
2348  if (*expectedResultType != resultType)
2349  return emitOpError("expected expanded type to be ")
2350  << *expectedResultType << " but found " << resultType;
2351 
2352  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2353  return emitOpError("expected number of static shape bounds to be equal to "
2354  "the output rank (")
2355  << resultType.getRank() << ") but found "
2356  << getStaticOutputShape().size() << " inputs instead";
2357 
2358  if ((int64_t)getOutputShape().size() !=
2359  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2360  return emitOpError("mismatch in dynamic dims in output_shape and "
2361  "static_output_shape: static_output_shape has ")
2362  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2363  << " dynamic dims while output_shape has " << getOutputShape().size()
2364  << " values";
2365 
2366  // Verify if provided output shapes are in agreement with output type.
2367  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2368  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2369  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2370  if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
2371  return emitOpError("invalid output shape provided at pos ") << pos;
2372  }
2373  }
2374 
2375  return success();
2376 }
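// A sketch of an expansion that passes the checks above (shapes and names are
// illustrative): the dynamic result dimension is provided via output_shape and
// lands in the group that collapses to the dynamic source dimension.
// ```
// %e = memref.expand_shape %m [[0, 1], [2]] output_shape [%sz, 4, 5]
//     : memref<?x5xf32> into memref<?x4x5xf32>
// ```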
2377 
2378 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2379  MLIRContext *context) {
2380  results.add<
2381  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2382  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2383 }
2384 
2385 /// Compute the layout map after collapsing a given source MemRef type with the
2386 /// specified reassociation indices.
2387 ///
2388 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2389 /// not possible to check this by inspecting a MemRefType in the general case.
2390 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2391 /// be valid (and thus accepted by this function) unless `strict = true`.
2392 static FailureOr<StridedLayoutAttr>
2393 computeCollapsedLayoutMap(MemRefType srcType,
2394  ArrayRef<ReassociationIndices> reassociation,
2395  bool strict = false) {
2396  int64_t srcOffset;
2397  SmallVector<int64_t> srcStrides;
2398  auto srcShape = srcType.getShape();
2399  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2400  return failure();
2401 
2402  // The result stride of a reassociation group is the stride of the last entry
2403  // of the reassociation. (TODO: Should be the minimum stride in the
2404  // reassociation because strides are not necessarily sorted. E.g., when using
2405  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2406  // strides are meaningless and could have any arbitrary value.
2407  SmallVector<int64_t> resultStrides;
2408  resultStrides.reserve(reassociation.size());
2409  for (const ReassociationIndices &reassoc : reassociation) {
2410  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2411  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2412  ref = ref.drop_back();
2413  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2414  resultStrides.push_back(srcStrides[ref.back()]);
2415  } else {
2416  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2417  // the corresponding stride may have to be skipped. (See above comment.)
2418  // Therefore, the result stride cannot be statically determined and must
2419  // be dynamic.
2420  resultStrides.push_back(ShapedType::kDynamic);
2421  }
2422  }
2423 
2424  // Validate that each reassociation group is contiguous.
2425  unsigned resultStrideIndex = resultStrides.size() - 1;
2426  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2427  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2428  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2429  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2430  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2431 
2432  // Both source and result stride must have the same static value. In that
2433  // case, we can be sure, that the dimensions are collapsible (because they
2434  // are contiguous).
2435  // If `strict = false` (default during op verification), we accept cases
2436  // where one or both strides are dynamic. This is best effort: We reject
2437  // ops where obviously non-contiguous dims are collapsed, but accept ops
2438  // where we cannot be sure statically. Such ops may fail at runtime. See
2439  // the op documentation for details.
2440  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2441  if (strict && (stride.saturated || srcStride.saturated))
2442  return failure();
2443 
2444  // Dimensions of size 1 should be skipped, because their strides are
2445  // meaningless and could have any arbitrary value.
2446  if (srcShape[idx - 1] == 1)
2447  continue;
2448 
2449  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2450  return failure();
2451  }
2452  }
2453  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2454 }
2455 
2456 bool CollapseShapeOp::isGuaranteedCollapsible(
2457  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2458  // MemRefs with identity layout are always collapsible.
2459  if (srcType.getLayout().isIdentity())
2460  return true;
2461 
2462  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2463  /*strict=*/true));
2464 }
2465 
2466 MemRefType CollapseShapeOp::computeCollapsedType(
2467  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2468  SmallVector<int64_t> resultShape;
2469  resultShape.reserve(reassociation.size());
2470  for (const ReassociationIndices &group : reassociation) {
2471  auto groupSize = SaturatedInteger::wrap(1);
2472  for (int64_t srcDim : group)
2473  groupSize =
2474  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2475  resultShape.push_back(groupSize.asInteger());
2476  }
2477 
2478  if (srcType.getLayout().isIdentity()) {
2479  // If the source is contiguous (i.e., no layout map specified), so is the
2480  // result.
2481  MemRefLayoutAttrInterface layout;
2482  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2483  srcType.getMemorySpace());
2484  }
2485 
2486  // Source may not be fully contiguous. Compute the layout map.
2487  // Note: Dimensions that are collapsed into a single dim are assumed to be
2488  // contiguous.
2489  FailureOr<StridedLayoutAttr> computedLayout =
2490  computeCollapsedLayoutMap(srcType, reassociation);
2491  assert(succeeded(computedLayout) &&
2492  "invalid source layout map or collapsing non-contiguous dims");
2493  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2494  srcType.getMemorySpace());
2495 }
2496 
2497 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2498  ArrayRef<ReassociationIndices> reassociation,
2499  ArrayRef<NamedAttribute> attrs) {
2500  auto srcType = llvm::cast<MemRefType>(src.getType());
2501  MemRefType resultType =
2502  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2503  result.addAttribute(getReassociationAttrStrName(),
2504  getReassociationIndicesAttribute(b, reassociation));
2505  build(b, result, resultType, src, attrs);
2506 }
2507 
2508 LogicalResult CollapseShapeOp::verify() {
2509  MemRefType srcType = getSrcType();
2510  MemRefType resultType = getResultType();
2511 
2512  if (srcType.getRank() < resultType.getRank()) {
2513  auto r0 = srcType.getRank();
2514  auto r1 = resultType.getRank();
2515  return emitOpError("has source rank ")
2516  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2517  << r0 << " < " << r1 << ").";
2518  }
2519 
2520  // Verify result shape.
2521  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2522  srcType.getShape(), getReassociationIndices(),
2523  /*allowMultipleDynamicDimsPerGroup=*/true)))
2524  return failure();
2525 
2526  // Compute expected result type (including layout map).
2527  MemRefType expectedResultType;
2528  if (srcType.getLayout().isIdentity()) {
2529  // If the source is contiguous (i.e., no layout map specified), so is the
2530  // result.
2531  MemRefLayoutAttrInterface layout;
2532  expectedResultType =
2533  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2534  srcType.getMemorySpace());
2535  } else {
2536  // Source may not be fully contiguous. Compute the layout map.
2537  // Note: Dimensions that are collapsed into a single dim are assumed to be
2538  // contiguous.
2539  FailureOr<StridedLayoutAttr> computedLayout =
2540  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2541  if (failed(computedLayout))
2542  return emitOpError(
2543  "invalid source layout map or collapsing non-contiguous dims");
2544  expectedResultType =
2545  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2546  *computedLayout, srcType.getMemorySpace());
2547  }
2548 
2549  if (expectedResultType != resultType)
2550  return emitOpError("expected collapsed type to be ")
2551  << expectedResultType << " but found " << resultType;
2552 
2553  return success();
2554 }
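// A sketch of a collapse that passes the checks above (shapes are
// illustrative); with an identity source layout, the groups are trivially
// contiguous and the collapsed sizes are the products of the group sizes.
// ```
// %c = memref.collapse_shape %m [[0, 1], [2]]
//     : memref<2x4x5xf32> into memref<8x5xf32>
// ```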
2555 
2556 struct CollapseShapeOpMemRefCastFolder
2557  : public OpRewritePattern<CollapseShapeOp> {
2558 public:
2559  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2560 
2561  LogicalResult matchAndRewrite(CollapseShapeOp op,
2562  PatternRewriter &rewriter) const override {
2563  auto cast = op.getOperand().getDefiningOp<CastOp>();
2564  if (!cast)
2565  return failure();
2566 
2567  if (!CastOp::canFoldIntoConsumerOp(cast))
2568  return failure();
2569 
2570  Type newResultType = CollapseShapeOp::computeCollapsedType(
2571  llvm::cast<MemRefType>(cast.getOperand().getType()),
2572  op.getReassociationIndices());
2573 
2574  if (newResultType == op.getResultType()) {
2575  rewriter.modifyOpInPlace(
2576  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2577  } else {
2578  Value newOp = rewriter.create<CollapseShapeOp>(
2579  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2580  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2581  }
2582  return success();
2583  }
2584 };
2585 
2586 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2587  MLIRContext *context) {
2588  results.add<
2589  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2590  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2591  memref::DimOp, MemRefType>,
2592  CollapseShapeOpMemRefCastFolder>(context);
2593 }
2594 
2595 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2596  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2597  adaptor.getOperands());
2598 }
2599 
2600 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2601  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2602  adaptor.getOperands());
2603 }
2604 
2605 //===----------------------------------------------------------------------===//
2606 // ReshapeOp
2607 //===----------------------------------------------------------------------===//
2608 
2609 void ReshapeOp::getAsmResultNames(
2610  function_ref<void(Value, StringRef)> setNameFn) {
2611  setNameFn(getResult(), "reshape");
2612 }
2613 
2614 LogicalResult ReshapeOp::verify() {
2615  Type operandType = getSource().getType();
2616  Type resultType = getResult().getType();
2617 
2618  Type operandElementType =
2619  llvm::cast<ShapedType>(operandType).getElementType();
2620  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2621  if (operandElementType != resultElementType)
2622  return emitOpError("element types of source and destination memref "
2623  "types should be the same");
2624 
2625  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2626  if (!operandMemRefType.getLayout().isIdentity())
2627  return emitOpError("source memref type should have identity affine map");
2628 
2629  int64_t shapeSize =
2630  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2631  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2632  if (resultMemRefType) {
2633  if (!resultMemRefType.getLayout().isIdentity())
2634  return emitOpError("result memref type should have identity affine map");
2635  if (shapeSize == ShapedType::kDynamic)
2636  return emitOpError("cannot use shape operand with dynamic length to "
2637  "reshape to statically-ranked memref type");
2638  if (shapeSize != resultMemRefType.getRank())
2639  return emitOpError(
2640  "length of shape operand differs from the result's memref rank");
2641  }
2642  return success();
2643 }
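// A sketch of a reshape that satisfies the constraints above (names and types
// are illustrative): identity layouts on both sides, matching element types,
// and a shape operand whose static length equals the result rank.
// ```
// %r = memref.reshape %src(%shape)
//     : (memref<4x5xf32>, memref<2xindex>) -> memref<?x?xf32>
// ```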
2644 
2645 //===----------------------------------------------------------------------===//
2646 // StoreOp
2647 //===----------------------------------------------------------------------===//
2648 
2649 LogicalResult StoreOp::verify() {
2650  if (getNumOperands() != 2 + getMemRefType().getRank())
2651  return emitOpError("store index operand count not equal to memref rank");
2652 
2653  return success();
2654 }
2655 
2656 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2657  SmallVectorImpl<OpFoldResult> &results) {
2658  /// store(memrefcast) -> store
2659  return foldMemRefCast(*this, getValueToStore());
2660 }
2661 
2662 //===----------------------------------------------------------------------===//
2663 // SubViewOp
2664 //===----------------------------------------------------------------------===//
2665 
2666 void SubViewOp::getAsmResultNames(
2667  function_ref<void(Value, StringRef)> setNameFn) {
2668  setNameFn(getResult(), "subview");
2669 }
2670 
2671 /// A subview result type can be fully inferred from the source type and the
2672 /// static representation of offsets, sizes and strides. Special sentinels
2673 /// encode the dynamic case.
2674 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2675  ArrayRef<int64_t> staticOffsets,
2676  ArrayRef<int64_t> staticSizes,
2677  ArrayRef<int64_t> staticStrides) {
2678  unsigned rank = sourceMemRefType.getRank();
2679  (void)rank;
2680  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2681  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2682  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2683 
2684  // Extract source offset and strides.
2685  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2686 
2687  // Compute target offset whose value is:
2688  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2689  int64_t targetOffset = sourceOffset;
2690  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2691  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2692  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2693  SaturatedInteger::wrap(staticOffset) *
2694  SaturatedInteger::wrap(sourceStride))
2695  .asInteger();
2696  }
2697 
2698  // Compute target stride whose value is:
2699  // `sourceStrides_i * staticStrides_i`.
2700  SmallVector<int64_t, 4> targetStrides;
2701  targetStrides.reserve(staticOffsets.size());
2702  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2703  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2704  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2705  SaturatedInteger::wrap(staticStride))
2706  .asInteger());
2707  }
2708 
2709  // The type is now known.
2710  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2711  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2712  targetOffset, targetStrides),
2713  sourceMemRefType.getMemorySpace());
2714 }
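// Worked sketch of the inference above (types and names are illustrative): for
// a source memref<8x16xf32> (strides [16, 1], offset 0) with offsets [%i, 4],
// sizes [4, 4], and strides [1, 1], the target strides are [16, 1] and the
// target offset %i * 16 + 4 is dynamic.
// ```
// %sv = memref.subview %A[%i, 4] [4, 4] [1, 1]
//     : memref<8x16xf32> to memref<4x4xf32, strided<[16, 1], offset: ?>>
// ```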
2715 
2716 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2717  ArrayRef<OpFoldResult> offsets,
2718  ArrayRef<OpFoldResult> sizes,
2719  ArrayRef<OpFoldResult> strides) {
2720  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2721  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2722  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2723  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2724  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2725  if (!hasValidSizesOffsets(staticOffsets))
2726  return {};
2727  if (!hasValidSizesOffsets(staticSizes))
2728  return {};
2729  if (!hasValidStrides(staticStrides))
2730  return {};
2731  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2732  staticSizes, staticStrides);
2733 }
2734 
2735 MemRefType SubViewOp::inferRankReducedResultType(
2736  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2737  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2738  ArrayRef<int64_t> strides) {
2739  MemRefType inferredType =
2740  inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2741  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2742  "expected the inferred rank to be at least the result shape rank");
2743  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2744  return inferredType;
2745 
2746  // Compute which dimensions are dropped.
2747  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2748  computeRankReductionMask(inferredType.getShape(), resultShape);
2749  assert(dimsToProject.has_value() && "invalid rank reduction");
2750 
2751  // Compute the layout and result type.
2752  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2753  SmallVector<int64_t> rankReducedStrides;
2754  rankReducedStrides.reserve(resultShape.size());
2755  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2756  if (!dimsToProject->contains(idx))
2757  rankReducedStrides.push_back(value);
2758  }
2759  return MemRefType::get(resultShape, inferredType.getElementType(),
2760  StridedLayoutAttr::get(inferredLayout.getContext(),
2761  inferredLayout.getOffset(),
2762  rankReducedStrides),
2763  inferredType.getMemorySpace());
2764 }
2765 
2766 MemRefType SubViewOp::inferRankReducedResultType(
2767  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2768  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2769  ArrayRef<OpFoldResult> strides) {
2770  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2771  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2772  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2773  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2774  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2775  return SubViewOp::inferRankReducedResultType(
2776  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2777  staticStrides);
2778 }
2779 
2780 // Build a SubViewOp with mixed static and dynamic entries and custom result
2781 // type. If the type passed is nullptr, it is inferred.
2782 void SubViewOp::build(OpBuilder &b, OperationState &result,
2783  MemRefType resultType, Value source,
2784  ArrayRef<OpFoldResult> offsets,
2785  ArrayRef<OpFoldResult> sizes,
2786  ArrayRef<OpFoldResult> strides,
2787  ArrayRef<NamedAttribute> attrs) {
2788  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2789  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2790  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2791  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2792  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2793  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2794  // Structuring implementation this way avoids duplication between builders.
2795  if (!resultType) {
2796  resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2797  staticSizes, staticStrides);
2798  }
2799  result.addAttributes(attrs);
2800  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2801  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2802  b.getDenseI64ArrayAttr(staticSizes),
2803  b.getDenseI64ArrayAttr(staticStrides));
2804 }
2805 
2806 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2807 // type.
2808 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2809  ArrayRef<OpFoldResult> offsets,
2810  ArrayRef<OpFoldResult> sizes,
2811  ArrayRef<OpFoldResult> strides,
2812  ArrayRef<NamedAttribute> attrs) {
2813  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2814 }
2815 
2816 // Build a SubViewOp with static entries and inferred result type.
2817 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2818  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2819  ArrayRef<int64_t> strides,
2820  ArrayRef<NamedAttribute> attrs) {
2821  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2822  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2823  return b.getI64IntegerAttr(v);
2824  }));
2825  SmallVector<OpFoldResult> sizeValues =
2826  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2827  return b.getI64IntegerAttr(v);
2828  }));
2829  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2830  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2831  return b.getI64IntegerAttr(v);
2832  }));
2833  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2834 }
2835 
2836 // Build a SubViewOp with dynamic entries and custom result type. If the
2837 // type passed is nullptr, it is inferred.
2838 void SubViewOp::build(OpBuilder &b, OperationState &result,
2839  MemRefType resultType, Value source,
2840  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2841  ArrayRef<int64_t> strides,
2842  ArrayRef<NamedAttribute> attrs) {
2843  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2844  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2845  return b.getI64IntegerAttr(v);
2846  }));
2847  SmallVector<OpFoldResult> sizeValues =
2848  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2849  return b.getI64IntegerAttr(v);
2850  }));
2851  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2852  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2853  return b.getI64IntegerAttr(v);
2854  }));
2855  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2856  attrs);
2857 }
2858 
2859 // Build a SubViewOp with dynamic entries and custom result type. If the type
2860 // passed is nullptr, it is inferred.
2861 void SubViewOp::build(OpBuilder &b, OperationState &result,
2862  MemRefType resultType, Value source, ValueRange offsets,
2863  ValueRange sizes, ValueRange strides,
2864  ArrayRef<NamedAttribute> attrs) {
2865  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2866  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2867  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2868  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2869  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2870  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2871  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2872 }
2873 
2874 // Build a SubViewOp with dynamic entries and inferred result type.
2875 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2876  ValueRange offsets, ValueRange sizes, ValueRange strides,
2877  ArrayRef<NamedAttribute> attrs) {
2878  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2879 }
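
// A minimal hedged sketch of the IR the all-dynamic overloads correspond to
// (SSA names are hypothetical): when every offset, size, and stride is an SSA
// value, the inferred result type is maximally dynamic, e.g.:
// ```
// %sv = memref.subview %src[%o0, %o1][%s0, %s1][%st0, %st1]
//     : memref<8x16xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// ```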
2880 
2881 /// For ViewLikeOpInterface.
2882 Value SubViewOp::getViewSource() { return getSource(); }
2883 
2884 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2885 /// static value).
2886 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2887  int64_t t1Offset, t2Offset;
2888  SmallVector<int64_t> t1Strides, t2Strides;
2889  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2890  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2891  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2892 }
2893 
2894 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2895 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2896 /// marked as dropped in `droppedDims`.
2897 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2898  const llvm::SmallBitVector &droppedDims) {
2899  assert(size_t(t1.getRank()) == droppedDims.size() &&
2900  "incorrect number of bits");
2901  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2902  "incorrect number of dropped dims");
2903  int64_t t1Offset, t2Offset;
2904  SmallVector<int64_t> t1Strides, t2Strides;
2905  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2906  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2907  if (failed(res1) || failed(res2))
2908  return false;
2909  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2910  if (droppedDims[i])
2911  continue;
2912  if (t1Strides[i] != t2Strides[j])
2913  return false;
2914  ++j;
2915  }
2916  return true;
2917 }
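
// A hedged example of the intended semantics (the types below are
// illustrative only): dropping dim 1 of `t1` leaves the strides [32, 1],
// which match the strides of `t2`, so the two types have compatible strides.
// ```
// t1: memref<4x1x8xf32, strided<[32, 8, 1]>>
// t2: memref<4x8xf32, strided<[32, 1]>>   (droppedDims = {1})
// ```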
 2918 
 2919 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
 2920  Operation *op, Type expectedType) {
 2921  auto memrefType = llvm::cast<ShapedType>(expectedType);
 2922  switch (result) {
 2923  case SliceVerificationResult::Success:
 2924  return success();
 2925  case SliceVerificationResult::RankTooLarge:
 2926  return op->emitError("expected result rank to be smaller or equal to ")
 2927  << "the source rank. ";
 2928  case SliceVerificationResult::SizeMismatch:
 2929  return op->emitError("expected result type to be ")
 2930  << expectedType
 2931  << " or a rank-reduced version. (mismatch of result sizes) ";
 2932  case SliceVerificationResult::ElemTypeMismatch:
 2933  return op->emitError("expected result element type to be ")
 2934  << memrefType.getElementType();
 2935  case SliceVerificationResult::MemSpaceMismatch:
 2936  return op->emitError("expected result and source memory spaces to match.");
 2937  case SliceVerificationResult::LayoutMismatch:
 2938  return op->emitError("expected result type to be ")
 2939  << expectedType
 2940  << " or a rank-reduced version. (mismatch of result layout) ";
 2941  }
2942  llvm_unreachable("unexpected subview verification result");
2943 }
2944 
2945 /// Verifier for SubViewOp.
2946 LogicalResult SubViewOp::verify() {
2947  MemRefType baseType = getSourceType();
2948  MemRefType subViewType = getType();
2949  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
2950  ArrayRef<int64_t> staticSizes = getStaticSizes();
2951  ArrayRef<int64_t> staticStrides = getStaticStrides();
2952 
2953  // The base memref and the view memref should be in the same memory space.
2954  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2955  return emitError("different memory spaces specified for base memref "
2956  "type ")
2957  << baseType << " and subview memref type " << subViewType;
2958 
2959  // Verify that the base memref type has a strided layout map.
2960  if (!baseType.isStrided())
2961  return emitError("base type ") << baseType << " is not strided";
2962 
2963  // Compute the expected result type, assuming that there are no rank
2964  // reductions.
2965  MemRefType expectedType = SubViewOp::inferResultType(
2966  baseType, staticOffsets, staticSizes, staticStrides);
2967 
2968  // Verify all properties of a shaped type: rank, element type and dimension
2969  // sizes. This takes into account potential rank reductions.
2970  auto shapedTypeVerification = isRankReducedType(
2971  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2972  if (shapedTypeVerification != SliceVerificationResult::Success)
2973  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2974 
2975  // Make sure that the memory space did not change.
2976  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
 2977  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
 2978  *this, expectedType);
2979 
2980  // Verify the offset of the layout map.
2981  if (!haveCompatibleOffsets(expectedType, subViewType))
 2982  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
 2983  *this, expectedType);
2984 
 2985  // Only the strides are left to verify. First, compute the unused
 2986  // dimensions due to rank reductions. We have to look at sizes and strides
 2987  // to decide which dimensions were dropped. This function also partially
 2988  // verifies strides in case of rank reductions.
2989  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
2990  getMixedSizes());
2991  if (failed(unusedDims))
 2992  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
 2993  *this, expectedType);
2994 
2995  // Strides must match.
2996  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
 2997  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
 2998  *this, expectedType);
2999 
3000  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3001  // to the base memref.
3002  SliceBoundsVerificationResult boundsResult =
3003  verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
3004  staticStrides, /*generateErrorMessage=*/true);
3005  if (!boundsResult.isValid)
3006  return getOperation()->emitError(boundsResult.errorMessage);
3007 
3008  return success();
3009 }
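
// A hedged example of a subview this verifier rejects (illustrative only):
// the slice below is well-typed, but in dimension 0 the last accessed index
// is 4 + (8 - 1) * 1 = 11, which runs past the dimension size 8, so the
// final in-bounds check fails.
// ```
// memref.subview %src[4, 0][8, 16][1, 1]
//     : memref<8x16xf32> to memref<8x16xf32, strided<[16, 1], offset: 64>>
// ```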
3010 
3011 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3012  return os << "range " << range.offset << ":" << range.size << ":"
3013  << range.stride;
3014 }
3015 
3016 /// Return the list of Range (i.e. offset, size, stride). Each Range
3017 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3018 /// with `b` at location `loc`.
3019 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3020  OpBuilder &b, Location loc) {
3021  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3022  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3023  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
 3024  SmallVector<Range, 8> res;
 3025  unsigned rank = ranks[0];
3026  res.reserve(rank);
3027  for (unsigned idx = 0; idx < rank; ++idx) {
3028  Value offset =
3029  op.isDynamicOffset(idx)
3030  ? op.getDynamicOffset(idx)
3031  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
3032  Value size =
3033  op.isDynamicSize(idx)
3034  ? op.getDynamicSize(idx)
3035  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
3036  Value stride =
3037  op.isDynamicStride(idx)
3038  ? op.getDynamicStride(idx)
3039  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
3040  res.emplace_back(Range{offset, size, stride});
3041  }
3042  return res;
3043 }
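
// A hedged illustration (op and SSA names are hypothetical): for a subview
// written as
// ```
// %sv = memref.subview %m[2, %o][4, 4][1, 1]
//     : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: ?>>
// ```
// the returned ranges are (2, 4, 1) and (%o, 4, 1), with each static entry
// materialized as an arith.constant of index type at `loc`.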
3044 
3045 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3046 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3047 /// the rank of the inferred result type if `currentResultType` is lower rank
3048 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3049 /// together with the result type. In this case, it is important to compute
3050 /// the dropped dimensions using `currentSourceType` whose strides align with
3051 /// `currentResultType`.
 3052 static MemRefType getCanonicalSubViewResultType(
 3053  MemRefType currentResultType, MemRefType currentSourceType,
3054  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3055  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3056  MemRefType nonRankReducedType = SubViewOp::inferResultType(
3057  sourceType, mixedOffsets, mixedSizes, mixedStrides);
3058  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3059  currentSourceType, currentResultType, mixedSizes);
3060  if (failed(unusedDims))
3061  return nullptr;
3062 
3063  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3064  SmallVector<int64_t> shape, strides;
3065  unsigned numDimsAfterReduction =
3066  nonRankReducedType.getRank() - unusedDims->count();
3067  shape.reserve(numDimsAfterReduction);
3068  strides.reserve(numDimsAfterReduction);
3069  for (const auto &[idx, size, stride] :
3070  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3071  nonRankReducedType.getShape(), layout.getStrides())) {
3072  if (unusedDims->test(idx))
3073  continue;
3074  shape.push_back(size);
3075  strides.push_back(stride);
3076  }
3077 
3078  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3079  StridedLayoutAttr::get(sourceType.getContext(),
3080  layout.getOffset(), strides),
3081  nonRankReducedType.getMemorySpace());
3082 }
3083 
 3084 Value mlir::memref::createCanonicalRankReducingSubViewOp(
 3085  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3086  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3087  unsigned rank = memrefType.getRank();
3088  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3089  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3090  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3091  MemRefType targetType = SubViewOp::inferRankReducedResultType(
3092  targetShape, memrefType, offsets, sizes, strides);
3093  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3094  sizes, strides);
3095 }
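
// A hedged example of the subview this helper creates (names and types are
// illustrative): for a source of type memref<4x1x8xf32> and targetShape
// [4, 8], the result is a rank-reducing subview over the entire source:
// ```
// %0 = memref.subview %m[0, 0, 0][4, 1, 8][1, 1, 1]
//     : memref<4x1x8xf32> to memref<4x8xf32, strided<[8, 1]>>
// ```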
3096 
3097 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3098  Value value,
3099  ArrayRef<int64_t> desiredShape) {
3100  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3101  assert(sourceMemrefType && "not a ranked memref type");
3102  auto sourceShape = sourceMemrefType.getShape();
3103  if (sourceShape.equals(desiredShape))
3104  return value;
3105  auto maybeRankReductionMask =
3106  mlir::computeRankReductionMask(sourceShape, desiredShape);
3107  if (!maybeRankReductionMask)
3108  return failure();
3109  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3110 }
3111 
 3112 /// Helper method to check if a `subview` operation is trivially a no-op. This
 3113 /// is the case if all offsets are zero, all strides are one, and the source
 3114 /// shape is the same as the size of the subview. In such cases, the subview
 3115 /// can be folded into its source.
3116 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3117  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3118  return false;
3119 
3120  auto mixedOffsets = subViewOp.getMixedOffsets();
3121  auto mixedSizes = subViewOp.getMixedSizes();
3122  auto mixedStrides = subViewOp.getMixedStrides();
3123 
3124  // Check offsets are zero.
3125  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3126  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3127  return !intValue || intValue.value() != 0;
3128  }))
3129  return false;
3130 
3131  // Check strides are one.
3132  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3133  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3134  return !intValue || intValue.value() != 1;
3135  }))
3136  return false;
3137 
 3138  // Check that all size values are static and match the (static) source shape.
3139  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3140  for (const auto &size : llvm::enumerate(mixedSizes)) {
3141  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3142  if (!intValue || *intValue != sourceShape[size.index()])
3143  return false;
3144  }
3145  // All conditions met. The `SubViewOp` is foldable as a no-op.
3146  return true;
3147 }
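
// A hedged example of a trivially foldable subview (illustrative only): all
// offsets are zero, all strides are one, and the sizes equal the full static
// source shape, so the op below is a no-op over its source.
// ```
// %0 = memref.subview %src[0, 0][8, 16][1, 1]
//     : memref<8x16xf32> to memref<8x16xf32>
// ```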
3148 
3149 namespace {
3150 /// Pattern to rewrite a subview op with MemRefCast arguments.
3151 /// This essentially pushes memref.cast past its consuming subview when
3152 /// `canFoldIntoConsumerOp` is true.
3153 ///
3154 /// Example:
3155 /// ```
3156 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3157 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3158 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3159 /// ```
3160 /// is rewritten into:
3161 /// ```
 3162 /// %0 = memref.subview %V[0, 0][3, 4][1, 1] : memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3163 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3164 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3165 /// ```
3166 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3167 public:
 3168  using OpRewritePattern<SubViewOp>::OpRewritePattern;
 3169 
3170  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3171  PatternRewriter &rewriter) const override {
 3172  // If any operand is a constant, just return and let
 3173  // SubViewOpConstantFolder kick in.
3174  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3175  return matchPattern(operand, matchConstantIndex());
3176  }))
3177  return failure();
3178 
3179  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3180  if (!castOp)
3181  return failure();
3182 
3183  if (!CastOp::canFoldIntoConsumerOp(castOp))
3184  return failure();
3185 
3186  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3187  // the MemRefCastOp source operand type to infer the result type and the
3188  // current SubViewOp source operand type to compute the dropped dimensions
3189  // if the operation is rank-reducing.
3190  auto resultType = getCanonicalSubViewResultType(
3191  subViewOp.getType(), subViewOp.getSourceType(),
3192  llvm::cast<MemRefType>(castOp.getSource().getType()),
3193  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3194  subViewOp.getMixedStrides());
3195  if (!resultType)
3196  return failure();
3197 
3198  Value newSubView = rewriter.create<SubViewOp>(
3199  subViewOp.getLoc(), resultType, castOp.getSource(),
3200  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3201  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3202  subViewOp.getStaticStrides());
3203  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3204  newSubView);
3205  return success();
3206  }
3207 };
3208 
 3209 /// Canonicalize subview ops that are no-ops. When the source type differs
 3210 /// from the result type (e.g. due to the use of `affine_map`), insert a cast.
3211 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3212 public:
 3213  using OpRewritePattern<SubViewOp>::OpRewritePattern;
 3214 
3215  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3216  PatternRewriter &rewriter) const override {
3217  if (!isTrivialSubViewOp(subViewOp))
3218  return failure();
3219  if (subViewOp.getSourceType() == subViewOp.getType()) {
3220  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3221  return success();
3222  }
3223  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3224  subViewOp.getSource());
3225  return success();
3226  }
3227 };
3228 } // namespace
3229 
3230 /// Return the canonical type of the result of a subview.
 3231 struct SubViewReturnTypeCanonicalizer {
 3232  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3233  ArrayRef<OpFoldResult> mixedSizes,
3234  ArrayRef<OpFoldResult> mixedStrides) {
3235  // Infer a memref type without taking into account any rank reductions.
3236  MemRefType resTy = SubViewOp::inferResultType(
3237  op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3238  if (!resTy)
3239  return {};
3240  MemRefType nonReducedType = resTy;
3241 
3242  // Directly return the non-rank reduced type if there are no dropped dims.
3243  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3244  if (droppedDims.none())
3245  return nonReducedType;
3246 
3247  // Take the strides and offset from the non-rank reduced type.
3248  auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3249 
3250  // Drop dims from shape and strides.
3251  SmallVector<int64_t> targetShape;
3252  SmallVector<int64_t> targetStrides;
3253  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3254  if (droppedDims.test(i))
3255  continue;
3256  targetStrides.push_back(nonReducedStrides[i]);
3257  targetShape.push_back(nonReducedType.getDimSize(i));
3258  }
3259 
3260  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3261  StridedLayoutAttr::get(nonReducedType.getContext(),
3262  offset, targetStrides),
3263  nonReducedType.getMemorySpace());
3264  }
3265 };
3266 
3267 /// A canonicalizer wrapper to replace SubViewOps.
 3268 struct SubViewCanonicalizer {
 3269  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3270  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3271  }
3272 };
3273 
3274 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3275  MLIRContext *context) {
 3276  results
 3277  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
 3278  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
 3279  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3280 }
3281 
3282 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3283  MemRefType sourceMemrefType = getSource().getType();
3284  MemRefType resultMemrefType = getResult().getType();
3285  auto resultLayout =
3286  dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3287 
3288  if (resultMemrefType == sourceMemrefType &&
3289  resultMemrefType.hasStaticShape() &&
3290  (!resultLayout || resultLayout.hasStaticLayout())) {
3291  return getViewSource();
3292  }
3293 
3294  // Fold subview(subview(x)), where both subviews have the same size and the
3295  // second subview's offsets are all zero. (I.e., the second subview is a
3296  // no-op.)
3297  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3298  auto srcSizes = srcSubview.getMixedSizes();
3299  auto sizes = getMixedSizes();
3300  auto offsets = getMixedOffsets();
3301  bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3302  auto strides = getMixedStrides();
3303  bool allStridesOne = llvm::all_of(strides, isOneInteger);
3304  bool allSizesSame = llvm::equal(sizes, srcSizes);
3305  if (allOffsetsZero && allStridesOne && allSizesSame &&
3306  resultMemrefType == sourceMemrefType)
3307  return getViewSource();
3308  }
3309 
3310  return {};
3311 }
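
// A hedged example of the subview(subview(x)) fold above (illustrative only):
// the outer subview has zero offsets, unit strides, the same sizes as the
// inner subview, and the same result type, so it folds to %0.
// ```
// %0 = memref.subview %src[%i, %j][4, 4][1, 1]
//     : memref<16x16xf32> to memref<4x4xf32, strided<[16, 1], offset: ?>>
// %1 = memref.subview %0[0, 0][4, 4][1, 1]
//     : memref<4x4xf32, strided<[16, 1], offset: ?>>
//       to memref<4x4xf32, strided<[16, 1], offset: ?>>
// ```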
3312 
3313 //===----------------------------------------------------------------------===//
3314 // TransposeOp
3315 //===----------------------------------------------------------------------===//
3316 
3317 void TransposeOp::getAsmResultNames(
3318  function_ref<void(Value, StringRef)> setNameFn) {
3319  setNameFn(getResult(), "transpose");
3320 }
3321 
3322 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3323 static MemRefType inferTransposeResultType(MemRefType memRefType,
3324  AffineMap permutationMap) {
3325  auto originalSizes = memRefType.getShape();
3326  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3327  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3328 
3329  // Compute permuted sizes and strides.
3330  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3331  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3332 
3333  return MemRefType::Builder(memRefType)
3334  .setShape(sizes)
3335  .setLayout(
3336  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3337 }
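
// A hedged example (types are illustrative): applying the permutation
// (d0, d1) -> (d1, d0) to memref<4x8xf32>, whose strides are [8, 1], yields
// ```
// memref<8x4xf32, strided<[1, 8]>>
// ```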
3338 
3339 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3340  AffineMapAttr permutation,
3341  ArrayRef<NamedAttribute> attrs) {
3342  auto permutationMap = permutation.getValue();
3343  assert(permutationMap);
3344 
3345  auto memRefType = llvm::cast<MemRefType>(in.getType());
3346  // Compute result type.
3347  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3348 
3349  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3350  build(b, result, resultType, in, attrs);
3351 }
3352 
3353 // transpose $in $permutation attr-dict : type($in) `to` type(results)
 3354 void TransposeOp::print(OpAsmPrinter &p) {
 3355  p << " " << getIn() << " " << getPermutation();
3356  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3357  p << " : " << getIn().getType() << " to " << getType();
3358 }
3359 
3360 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
 3361  OpAsmParser::UnresolvedOperand in;
 3362  AffineMap permutation;
3363  MemRefType srcType, dstType;
3364  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3365  parser.parseOptionalAttrDict(result.attributes) ||
3366  parser.parseColonType(srcType) ||
3367  parser.resolveOperand(in, srcType, result.operands) ||
3368  parser.parseKeywordType("to", dstType) ||
3369  parser.addTypeToList(dstType, result.types))
3370  return failure();
3371 
3372  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3373  AffineMapAttr::get(permutation));
3374  return success();
3375 }
3376 
3377 LogicalResult TransposeOp::verify() {
3378  if (!getPermutation().isPermutation())
3379  return emitOpError("expected a permutation map");
3380  if (getPermutation().getNumDims() != getIn().getType().getRank())
3381  return emitOpError("expected a permutation map of same rank as the input");
3382 
3383  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3384  auto resultType = llvm::cast<MemRefType>(getType());
3385  auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3386  .canonicalizeStridedLayout();
3387 
3388  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3389  return emitOpError("result type ")
3390  << resultType
3391  << " is not equivalent to the canonical transposed input type "
3392  << canonicalResultType;
3393  return success();
3394 }
3395 
3396 OpFoldResult TransposeOp::fold(FoldAdaptor) {
 3397  // First, check for an identity permutation; it can be folded away if the
 3398  // input and result types are already identical.
3399  if (getPermutation().isIdentity() && getType() == getIn().getType())
3400  return getIn();
3401  // Fold two consecutive memref.transpose Ops into one by composing their
3402  // permutation maps.
3403  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3404  AffineMap composedPermutation =
3405  getPermutation().compose(otherTransposeOp.getPermutation());
3406  getInMutable().assign(otherTransposeOp.getIn());
3407  setPermutation(composedPermutation);
3408  return getResult();
3409  }
3410  return {};
3411 }
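
// A hedged example of the identity-permutation fold above (SSA names are
// hypothetical): the permutation is the identity and the types match, so %t
// folds to %m.
// ```
// %t = memref.transpose %m (d0, d1) -> (d0, d1)
//     : memref<4x8xf32> to memref<4x8xf32>
// ```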
3412 
3413 //===----------------------------------------------------------------------===//
3414 // ViewOp
3415 //===----------------------------------------------------------------------===//
3416 
3417 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3418  setNameFn(getResult(), "view");
3419 }
3420 
3421 LogicalResult ViewOp::verify() {
3422  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3423  auto viewType = getType();
3424 
3425  // The base memref should have identity layout map (or none).
3426  if (!baseType.getLayout().isIdentity())
3427  return emitError("unsupported map for base memref type ") << baseType;
3428 
3429  // The result memref should have identity layout map (or none).
3430  if (!viewType.getLayout().isIdentity())
3431  return emitError("unsupported map for result memref type ") << viewType;
3432 
3433  // The base memref and the view memref should be in the same memory space.
3434  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3435  return emitError("different memory spaces specified for base memref "
3436  "type ")
3437  << baseType << " and view memref type " << viewType;
3438 
3439  // Verify that we have the correct number of sizes for the result type.
3440  unsigned numDynamicDims = viewType.getNumDynamicDims();
3441  if (getSizes().size() != numDynamicDims)
3442  return emitError("incorrect number of size operands for type ") << viewType;
3443 
3444  return success();
3445 }
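
// A hedged example of a view that passes these checks (SSA names are
// hypothetical): base and result both use the identity layout, share the
// default memory space, and exactly one size operand is supplied for the one
// dynamic result dimension.
// ```
// %v = memref.view %buffer[%byte_shift][%size]
//     : memref<2048xi8> to memref<?x4xf32>
// ```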
3446 
3447 Value ViewOp::getViewSource() { return getSource(); }
3448 
3449 namespace {
3450 
3451 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
 3452  using OpRewritePattern<ViewOp>::OpRewritePattern;
 3453 
3454  LogicalResult matchAndRewrite(ViewOp viewOp,
3455  PatternRewriter &rewriter) const override {
3456  // Return if none of the operands are constants.
3457  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3458  return matchPattern(operand, matchConstantIndex());
3459  }))
3460  return failure();
3461 
3462  // Get result memref type.
3463  auto memrefType = viewOp.getType();
3464 
 3465  // Get the offset from the old memref view type 'memrefType'.
3466  int64_t oldOffset;
3467  SmallVector<int64_t, 4> oldStrides;
3468  if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3469  return failure();
3470  assert(oldOffset == 0 && "Expected 0 offset");
3471 
3472  SmallVector<Value, 4> newOperands;
3473 
3474  // Offset cannot be folded into result type.
3475 
3476  // Fold any dynamic dim operands which are produced by a constant.
3477  SmallVector<int64_t, 4> newShapeConstants;
3478  newShapeConstants.reserve(memrefType.getRank());
3479 
3480  unsigned dynamicDimPos = 0;
3481  unsigned rank = memrefType.getRank();
3482  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3483  int64_t dimSize = memrefType.getDimSize(dim);
 3484  // If this is already a static dimension, keep it.
3485  if (!ShapedType::isDynamic(dimSize)) {
3486  newShapeConstants.push_back(dimSize);
3487  continue;
3488  }
3489  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3490  if (auto constantIndexOp =
3491  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3492  // Dynamic shape dimension will be folded.
3493  newShapeConstants.push_back(constantIndexOp.value());
3494  } else {
3495  // Dynamic shape dimension not folded; copy operand from old memref.
3496  newShapeConstants.push_back(dimSize);
3497  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3498  }
3499  dynamicDimPos++;
3500  }
3501 
3502  // Create new memref type with constant folded dims.
3503  MemRefType newMemRefType =
3504  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3505  // Nothing new, don't fold.
3506  if (newMemRefType == memrefType)
3507  return failure();
3508 
3509  // Create new ViewOp.
3510  auto newViewOp = rewriter.create<ViewOp>(
3511  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3512  viewOp.getByteShift(), newOperands);
3513  // Insert a cast so we have the same type as the old memref type.
3514  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3515  return success();
3516  }
3517 };
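
// A hedged sketch of the rewrite above (SSA names are hypothetical): if the
// dynamic size operand is produced by `arith.constant 64 : index`, the
// dimension is folded into the result type and a cast restores the original
// type for existing users, roughly:
// ```
// %v = memref.view %buf[%c0][%c64] : memref<2048xi8> to memref<?x8xf32>
// ```
// becomes
// ```
// %new = memref.view %buf[%c0][] : memref<2048xi8> to memref<64x8xf32>
// %v = memref.cast %new : memref<64x8xf32> to memref<?x8xf32>
// ```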
3518 
3519 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
 3520  using OpRewritePattern<ViewOp>::OpRewritePattern;
 3521 
3522  LogicalResult matchAndRewrite(ViewOp viewOp,
3523  PatternRewriter &rewriter) const override {
3524  Value memrefOperand = viewOp.getOperand(0);
3525  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3526  if (!memrefCastOp)
3527  return failure();
3528  Value allocOperand = memrefCastOp.getOperand();
3529  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3530  if (!allocOp)
3531  return failure();
3532  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3533  viewOp.getByteShift(),
3534  viewOp.getSizes());
3535  return success();
3536  }
3537 };
3538 
3539 } // namespace
3540 
3541 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3542  MLIRContext *context) {
3543  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3544 }
3545 
3546 //===----------------------------------------------------------------------===//
3547 // AtomicRMWOp
3548 //===----------------------------------------------------------------------===//
3549 
3550 LogicalResult AtomicRMWOp::verify() {
3551  if (getMemRefType().getRank() != getNumOperands() - 2)
3552  return emitOpError(
3553  "expects the number of subscripts to be equal to memref rank");
3554  switch (getKind()) {
3555  case arith::AtomicRMWKind::addf:
3556  case arith::AtomicRMWKind::maximumf:
3557  case arith::AtomicRMWKind::minimumf:
3558  case arith::AtomicRMWKind::mulf:
3559  if (!llvm::isa<FloatType>(getValue().getType()))
3560  return emitOpError() << "with kind '"
3561  << arith::stringifyAtomicRMWKind(getKind())
3562  << "' expects a floating-point type";
3563  break;
3564  case arith::AtomicRMWKind::addi:
3565  case arith::AtomicRMWKind::maxs:
3566  case arith::AtomicRMWKind::maxu:
3567  case arith::AtomicRMWKind::mins:
3568  case arith::AtomicRMWKind::minu:
3569  case arith::AtomicRMWKind::muli:
3570  case arith::AtomicRMWKind::ori:
3571  case arith::AtomicRMWKind::andi:
3572  if (!llvm::isa<IntegerType>(getValue().getType()))
3573  return emitOpError() << "with kind '"
3574  << arith::stringifyAtomicRMWKind(getKind())
3575  << "' expects an integer type";
3576  break;
3577  default:
3578  break;
3579  }
3580  return success();
3581 }
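
// A hedged example that this verifier accepts (SSA names are hypothetical):
// an `addf` kind with a floating-point value and one subscript per memref
// dimension.
// ```
// %old = memref.atomic_rmw addf %value, %mem[%i] : (f32, memref<10xf32>) -> f32
// ```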
3582 
3583 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3584  /// atomicrmw(memrefcast) -> atomicrmw
3585  if (succeeded(foldMemRefCast(*this, getValue())))
3586  return getResult();
3587  return OpFoldResult();
3588 }
3589 
3590 //===----------------------------------------------------------------------===//
3591 // TableGen'd op method definitions
3592 //===----------------------------------------------------------------------===//
3593 
3594 #define GET_OP_CLASSES
3595 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"