MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::memref;
29 
30 /// Materialize a single constant operation from a given attribute value with
31 /// the desired resultant type.
32 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
33  Attribute value, Type type,
34  Location loc) {
35  return arith::ConstantOp::materialize(builder, value, type, loc);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Common canonicalization pattern support logic
40 //===----------------------------------------------------------------------===//
41 
42 /// This is a common utility used for patterns of the form
43 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
44 /// into the root operation directly.
45 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
46  bool folded = false;
47  for (OpOperand &operand : op->getOpOperands()) {
48  auto cast = operand.get().getDefiningOp<CastOp>();
49  if (cast && operand.get() != inner &&
50  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
51  operand.set(cast.getOperand());
52  folded = true;
53  }
54  }
55  return success(folded);
56 }
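// For illustration, a minimal sketch of the kind of rewrite this helper
// enables (the consuming op here is memref.dealloc, whose fold uses it):
//
// ```mlir
// %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
// memref.dealloc %0 : memref<?xf32>
// ```
//
// may fold to
//
// ```mlir
// memref.dealloc %arg : memref<8xf32>
// ```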
57 
58 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
59 /// type.
60 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
61  if (auto memref = llvm::dyn_cast<MemRefType>(type))
62  return RankedTensorType::get(memref.getShape(), memref.getElementType());
63  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
64  return UnrankedTensorType::get(memref.getElementType());
65  return NoneType::get(type.getContext());
66 }
67 
68 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
69  int64_t dim) {
70  auto memrefType = llvm::cast<MemRefType>(value.getType());
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78  Location loc, Value value) {
79  auto memrefType = llvm::cast<MemRefType>(value.getType());
80  SmallVector<OpFoldResult> result;
81  for (int64_t i = 0; i < memrefType.getRank(); ++i)
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result;
84 }
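// For illustration, a minimal sketch: for a value %m of type memref<4x?xf32>,
// getMixedSizes returns the index attribute 4 for dimension 0 and, for
// dimension 1, the (possibly folded) result of
//
// ```mlir
// %c1 = arith.constant 1 : index
// %d1 = memref.dim %m, %c1 : memref<4x?xf32>
// ```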
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that sets values[i] to constValues[i] if the latter is a
91 /// static value (i.e., not ShapedType::kDynamic).
92 ///
93 /// If constValues[i] is dynamic, tries to extract a constant value from
94 /// values[i] to allow for additional folding opportunities. Also converts all
95 /// existing attributes to index attributes. (They may be i64 attributes.)
96 static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
97  ArrayRef<int64_t> constValues) {
98  assert(constValues.size() == values.size() &&
99  "incorrect number of const values");
100  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
101  Builder builder(values[i].getContext());
102  if (ShapedType::isStatic(cstVal)) {
103  // Constant value is known, use it directly.
104  values[i] = builder.getIndexAttr(cstVal);
105  continue;
106  }
107  if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
108  // Try to extract a constant or convert an existing attribute to an
109  // index attribute.
109  values[i] = builder.getIndexAttr(*cst);
110  }
111  }
112 }
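// For illustration (hypothetical values): with values = {%c4, %n}, where %c4
// is defined by `arith.constant 4 : index`, and constValues = {kDynamic, 8},
// the helper rewrites values to {index attr 4, index attr 8}: the first entry
// is constified from the SSA constant, the second from the static size.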
113 
114 //===----------------------------------------------------------------------===//
115 // AllocOp / AllocaOp
116 //===----------------------------------------------------------------------===//
117 
118 void AllocOp::getAsmResultNames(
119  function_ref<void(Value, StringRef)> setNameFn) {
120  setNameFn(getResult(), "alloc");
121 }
122 
123 void AllocaOp::getAsmResultNames(
124  function_ref<void(Value, StringRef)> setNameFn) {
125  setNameFn(getResult(), "alloca");
126 }
127 
128 template <typename AllocLikeOp>
129 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
130  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
131  "applies to only alloc or alloca");
132  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
133  if (!memRefType)
134  return op.emitOpError("result must be a memref");
135 
136  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
137  return op.emitOpError("dimension operand count does not equal memref "
138  "dynamic dimension count");
139 
140  unsigned numSymbols = 0;
141  if (!memRefType.getLayout().isIdentity())
142  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
143  if (op.getSymbolOperands().size() != numSymbols)
144  return op.emitOpError("symbol operand count does not equal memref symbol "
145  "count: expected ")
146  << numSymbols << ", got " << op.getSymbolOperands().size();
147 
148  return success();
149 }
150 
151 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
152 
153 LogicalResult AllocaOp::verify() {
154  // An alloca op needs to have an ancestor with an allocation scope trait.
155  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
156  return emitOpError(
157  "requires an ancestor op with AutomaticAllocationScope trait");
158 
159  return verifyAllocLikeOp(*this);
160 }
161 
162 namespace {
163 /// Fold constant dimensions into an alloc like operation.
164 template <typename AllocLikeOp>
165 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
166  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
167 
168  LogicalResult matchAndRewrite(AllocLikeOp alloc,
169  PatternRewriter &rewriter) const override {
170  // Check to see if any dimension operands are constants. If so, we can
171  // substitute and drop them.
172  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
173  APInt constSizeArg;
174  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
175  return false;
176  return constSizeArg.isNonNegative();
177  }))
178  return failure();
179 
180  auto memrefType = alloc.getType();
181 
182  // Ok, we have one or more constant operands. Collect the non-constant ones
183  // and keep track of the resultant memref type to build.
184  SmallVector<int64_t, 4> newShapeConstants;
185  newShapeConstants.reserve(memrefType.getRank());
186  SmallVector<Value, 4> dynamicSizes;
187 
188  unsigned dynamicDimPos = 0;
189  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
190  int64_t dimSize = memrefType.getDimSize(dim);
191  // If this is already a static dimension, keep it.
192  if (ShapedType::isStatic(dimSize)) {
193  newShapeConstants.push_back(dimSize);
194  continue;
195  }
196  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
197  APInt constSizeArg;
198  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
199  constSizeArg.isNonNegative()) {
200  // Dynamic shape dimension will be folded.
201  newShapeConstants.push_back(constSizeArg.getZExtValue());
202  } else {
203  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
204  newShapeConstants.push_back(ShapedType::kDynamic);
205  dynamicSizes.push_back(dynamicSize);
206  }
207  dynamicDimPos++;
208  }
209 
210  // Create new memref type (which will have fewer dynamic dimensions).
211  MemRefType newMemRefType =
212  MemRefType::Builder(memrefType).setShape(newShapeConstants);
213  assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
214 
215  // Create and insert the alloc op for the new memref.
216  auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
217  dynamicSizes, alloc.getSymbolOperands(),
218  alloc.getAlignmentAttr());
219  // Insert a cast so we have the same type as the old alloc.
220  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
221  return success();
222  }
223 };
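// For illustration, a minimal sketch of this canonicalization:
//
// ```mlir
// %c8 = arith.constant 8 : index
// %a = memref.alloc(%c8) : memref<?xf32>
// ```
//
// becomes
//
// ```mlir
// %0 = memref.alloc() : memref<8xf32>
// %a = memref.cast %0 : memref<8xf32> to memref<?xf32>
// ```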
224 
225 /// Fold alloc operations with no users or only store and dealloc uses.
226 template <typename T>
227 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
228  using OpRewritePattern<T>::OpRewritePattern;
229 
230  LogicalResult matchAndRewrite(T alloc,
231  PatternRewriter &rewriter) const override {
232  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
233  if (auto storeOp = dyn_cast<StoreOp>(op))
234  return storeOp.getValue() == alloc;
235  return !isa<DeallocOp>(op);
236  }))
237  return failure();
238 
239  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
240  rewriter.eraseOp(user);
241 
242  rewriter.eraseOp(alloc);
243  return success();
244  }
245 };
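// For illustration, a minimal sketch: an allocation whose only uses are stores
// into it and its dealloc is dead, so all of
//
// ```mlir
// %a = memref.alloc() : memref<4xf32>
// memref.store %v, %a[%i] : memref<4xf32>
// memref.dealloc %a : memref<4xf32>
// ```
//
// can be erased.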
246 } // namespace
247 
248 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
249  MLIRContext *context) {
250  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
251 }
252 
253 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
254  MLIRContext *context) {
255  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
256  context);
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // ReallocOp
261 //===----------------------------------------------------------------------===//
262 
263 LogicalResult ReallocOp::verify() {
264  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
265  MemRefType resultType = getType();
266 
267  // The source memref should have identity layout (or none).
268  if (!sourceType.getLayout().isIdentity())
269  return emitError("unsupported layout for source memref type ")
270  << sourceType;
271 
272  // The result memref should have identity layout (or none).
273  if (!resultType.getLayout().isIdentity())
274  return emitError("unsupported layout for result memref type ")
275  << resultType;
276 
277  // The source memref and the result memref should be in the same memory space.
278  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
279  return emitError("different memory spaces specified for source memref "
280  "type ")
281  << sourceType << " and result memref type " << resultType;
282 
283  // The source memref and the result memref should have the same element type.
284  if (sourceType.getElementType() != resultType.getElementType())
285  return emitError("different element types specified for source memref "
286  "type ")
287  << sourceType << " and result memref type " << resultType;
288 
289  // Verify that we have the dynamic dimension operand when it is needed.
290  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
291  return emitError("missing dimension operand for result type ")
292  << resultType;
293  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
294  return emitError("unnecessary dimension operand for result type ")
295  << resultType;
296 
297  return success();
298 }
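// For illustration, a sketch of a realloc that passes this verifier (the
// dynamic result size operand is required because the result type has a
// dynamic dimension; the assembly below is schematic):
//
// ```mlir
// %new = memref.realloc %old(%n) : memref<64xf32> to memref<?xf32>
// ```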
299 
300 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
301  MLIRContext *context) {
302  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // AllocaScopeOp
307 //===----------------------------------------------------------------------===//
308 
309 void AllocaScopeOp::print(OpAsmPrinter &p) {
310  bool printBlockTerminators = false;
311 
312  p << ' ';
313  if (!getResults().empty()) {
314  p << " -> (" << getResultTypes() << ")";
315  printBlockTerminators = true;
316  }
317  p << ' ';
318  p.printRegion(getBodyRegion(),
319  /*printEntryBlockArgs=*/false,
320  /*printBlockTerminators=*/printBlockTerminators);
321  p.printOptionalAttrDict((*this)->getAttrs());
322 }
323 
324 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
325  // Create a region for the body.
326  result.regions.reserve(1);
327  Region *bodyRegion = result.addRegion();
328 
329  // Parse optional results type list.
330  if (parser.parseOptionalArrowTypeList(result.types))
331  return failure();
332 
333  // Parse the body region.
334  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
335  return failure();
336  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
337  result.location);
338 
339  // Parse the optional attribute list.
340  if (parser.parseOptionalAttrDict(result.attributes))
341  return failure();
342 
343  return success();
344 }
345 
346 void AllocaScopeOp::getSuccessorRegions(
347  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
348  if (!point.isParent()) {
349  regions.push_back(RegionSuccessor(getResults()));
350  return;
351  }
352 
353  regions.push_back(RegionSuccessor(&getBodyRegion()));
354 }
355 
356 /// Given an operation, return whether this op is guaranteed to
357 /// allocate an AutomaticAllocationScopeResource
358 static bool isGuaranteedAutomaticAllocation(Operation *op) {
359  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
360  if (!interface)
361  return false;
362  for (auto res : op->getResults()) {
363  if (auto effect =
364  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
365  if (isa<SideEffects::AutomaticAllocationScopeResource>(
366  effect->getResource()))
367  return true;
368  }
369  }
370  return false;
371 }
372 
373 /// Given an operation, return whether this op itself could
374 /// allocate an AutomaticAllocationScopeResource. Note that
375 /// this will not check whether an operation contained within
376 /// the op can allocate.
377 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
378  // This op itself doesn't create a stack allocation,
379  // the inner allocation should be handled separately.
380  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
381  return false;
382  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
383  if (!interface)
384  return true;
385  for (auto res : op->getResults()) {
386  if (auto effect =
387  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
388  if (isa<SideEffects::AutomaticAllocationScopeResource>(
389  effect->getResource()))
390  return true;
391  }
392  }
393  return false;
394 }
395 
396 /// Return whether this op is the last non-terminating op
397 /// in a region. That is to say, it is in a one-block region
398 /// and is only followed by a terminator. This prevents
399 /// extending the lifetime of allocations.
400 static bool lastNonTerminatorInRegion(Operation *op) {
401  return op->getBlock()->mightHaveTerminator() &&
402  op->getNextNode() == op->getBlock()->getTerminator() &&
403  op->getParentRegion()->hasOneBlock();
404 }
405 
406 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
407 /// or it contains no allocation.
408 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
409  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
410 
411  LogicalResult matchAndRewrite(AllocaScopeOp op,
412  PatternRewriter &rewriter) const override {
413  bool hasPotentialAlloca =
414  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
415  if (alloc == op)
416  return WalkResult::advance();
417  if (isOpItselfPotentialAutomaticAllocation(alloc))
418  return WalkResult::interrupt();
419  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
420  return WalkResult::skip();
421  return WalkResult::advance();
422  }).wasInterrupted();
423 
424  // If this contains no potential allocation, it is always legal to
425  // inline. Otherwise, consider two conditions:
426  if (hasPotentialAlloca) {
427  // If the parent isn't an allocation scope, or we are not the last
428  // non-terminator op in the parent, we will extend the lifetime.
429  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
430  return failure();
431  if (!lastNonTerminatorInRegion(op))
432  return failure();
433  }
434 
435  Block *block = &op.getRegion().front();
436  Operation *terminator = block->getTerminator();
437  ValueRange results = terminator->getOperands();
438  rewriter.inlineBlockBefore(block, op);
439  rewriter.replaceOp(op, results);
440  rewriter.eraseOp(terminator);
441  return success();
442  }
443 };
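// For illustration, a minimal sketch: a scope whose body performs no stack
// allocation can always be inlined (`test.use` stands in for an arbitrary
// consumer), so
//
// ```mlir
// memref.alloca_scope {
//   %0 = arith.addi %a, %b : index
//   "test.use"(%0) : (index) -> ()
// }
// ```
//
// is replaced by its body spliced into the parent block.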
444 
445 /// Move allocations into an allocation scope, if it is legal to
446 /// move them (e.g. their operands are available at the location
447 /// the op would be moved to).
448 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
449  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
450 
451  LogicalResult matchAndRewrite(AllocaScopeOp op,
452  PatternRewriter &rewriter) const override {
453 
454  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
455  return failure();
456 
457  Operation *lastParentWithoutScope = op->getParentOp();
458 
459  if (!lastParentWithoutScope ||
460  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
461  return failure();
462 
463  // Only apply if this is the last non-terminator
464  // op in the block (lest the lifetime be extended) of a
465  // one-block region.
466  if (!lastNonTerminatorInRegion(op) ||
467  !lastNonTerminatorInRegion(lastParentWithoutScope))
468  return failure();
469 
470  while (!lastParentWithoutScope->getParentOp()
471  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
472  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
473  if (!lastParentWithoutScope ||
474  !lastNonTerminatorInRegion(lastParentWithoutScope))
475  return failure();
476  }
477  assert(lastParentWithoutScope->getParentOp()
478  ->hasTrait<OpTrait::AutomaticAllocationScope>());
479 
480  Region *containingRegion = nullptr;
481  for (auto &r : lastParentWithoutScope->getRegions()) {
482  if (r.isAncestor(op->getParentRegion())) {
483  assert(containingRegion == nullptr &&
484  "only one region can contain the op");
485  containingRegion = &r;
486  }
487  }
488  assert(containingRegion && "op must be contained in a region");
489 
490  SmallVector<Operation *> toHoist;
491  op->walk([&](Operation *alloc) {
492  if (!isGuaranteedAutomaticAllocation(alloc))
493  return WalkResult::skip();
494 
495  // If any operand is not defined before the location of
496  // lastParentWithoutScope (i.e. where we would hoist to), skip.
497  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
498  return containingRegion->isAncestor(v.getParentRegion());
499  }))
500  return WalkResult::skip();
501  toHoist.push_back(alloc);
502  return WalkResult::advance();
503  });
504 
505  if (toHoist.empty())
506  return failure();
507  rewriter.setInsertionPoint(lastParentWithoutScope);
508  for (auto *op : toHoist) {
509  auto *cloned = rewriter.clone(*op);
510  rewriter.replaceOp(op, cloned->getResults());
511  }
512  return success();
513  }
514 };
515 
516 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
517  MLIRContext *context) {
518  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // AssumeAlignmentOp
523 //===----------------------------------------------------------------------===//
524 
525 LogicalResult AssumeAlignmentOp::verify() {
526  if (!llvm::isPowerOf2_32(getAlignment()))
527  return emitOpError("alignment must be power of 2");
528  return success();
529 }
530 
531 void AssumeAlignmentOp::getAsmResultNames(
532  function_ref<void(Value, StringRef)> setNameFn) {
533  setNameFn(getResult(), "assume_align");
534 }
535 
536 OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
537  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
538  if (!source)
539  return {};
540  if (source.getAlignment() != getAlignment())
541  return {};
542  return getMemref();
543 }
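// For illustration, a minimal sketch (the assembly below is schematic): a
// repeated assumption with the same alignment folds away, so %1 folds to %0.
//
// ```mlir
// %0 = memref.assume_alignment %m, 16 : memref<?xf32>
// %1 = memref.assume_alignment %0, 16 : memref<?xf32>
// ```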
544 
545 //===----------------------------------------------------------------------===//
546 // CastOp
547 //===----------------------------------------------------------------------===//
548 
549 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
550  setNameFn(getResult(), "cast");
551 }
552 
553 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
554 /// source memref. This is useful to fold a memref.cast into a consuming op
555 /// and implement canonicalization patterns for ops in different dialects that
556 /// may consume the results of memref.cast operations. Such foldable memref.cast
557 /// operations are typically inserted as `view` and `subview` ops are
558 /// canonicalized, to preserve the type compatibility of their uses.
559 ///
560 /// Returns true when all conditions are met:
561 /// 1. source and result are ranked memrefs with strided semantics and same
562 /// element type and rank.
563 /// 2. each of the source's size, offset or stride has more static information
564 /// than the corresponding result's size, offset or stride.
565 ///
566 /// Example 1:
567 /// ```mlir
568 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
569 /// %2 = consumer %1 ... : memref<?x?xf32> ...
570 /// ```
571 ///
572 /// may fold into:
573 ///
574 /// ```mlir
575 /// %2 = consumer %0 ... : memref<8x16xf32> ...
576 /// ```
577 ///
578 /// Example 2:
579 /// ```
580 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
581 /// to memref<?x?xf32>
582 /// consumer %1 : memref<?x?xf32> ...
583 /// ```
584 ///
585 /// may fold into:
586 ///
587 /// ```
588 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
589 /// ```
590 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
591  MemRefType sourceType =
592  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
593  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
594 
595  // Requires ranked MemRefType.
596  if (!sourceType || !resultType)
597  return false;
598 
599  // Requires same elemental type.
600  if (sourceType.getElementType() != resultType.getElementType())
601  return false;
602 
603  // Requires same rank.
604  if (sourceType.getRank() != resultType.getRank())
605  return false;
606 
607  // Only fold casts between strided memref forms.
608  int64_t sourceOffset, resultOffset;
609  SmallVector<int64_t, 4> sourceStrides, resultStrides;
610  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
611  failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
612  return false;
613 
614  // If cast is towards more static sizes along any dimension, don't fold.
615  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
616  auto ss = std::get<0>(it), st = std::get<1>(it);
617  if (ss != st)
618  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
619  return false;
620  }
621 
622  // If cast is towards more static offset along any dimension, don't fold.
623  if (sourceOffset != resultOffset)
624  if (ShapedType::isDynamic(sourceOffset) &&
625  ShapedType::isStatic(resultOffset))
626  return false;
627 
628  // If cast is towards more static strides along any dimension, don't fold.
629  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
630  auto ss = std::get<0>(it), st = std::get<1>(it);
631  if (ss != st)
632  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
633  return false;
634  }
635 
636  return true;
637 }
638 
639 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
640  if (inputs.size() != 1 || outputs.size() != 1)
641  return false;
642  Type a = inputs.front(), b = outputs.front();
643  auto aT = llvm::dyn_cast<MemRefType>(a);
644  auto bT = llvm::dyn_cast<MemRefType>(b);
645 
646  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
647  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
648 
649  if (aT && bT) {
650  if (aT.getElementType() != bT.getElementType())
651  return false;
652  if (aT.getLayout() != bT.getLayout()) {
653  int64_t aOffset, bOffset;
654  SmallVector<int64_t, 4> aStrides, bStrides;
655  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
656  failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
657  aStrides.size() != bStrides.size())
658  return false;
659 
660  // Strides along a dimension/offset are compatible if the value in the
661  // source memref is static and the value in the target memref is the
662  // same. They are also compatible if either one is dynamic (see
663  // description of MemRefCastOp for details).
664  auto checkCompatible = [](int64_t a, int64_t b) {
665  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
666  };
667  if (!checkCompatible(aOffset, bOffset))
668  return false;
669  for (const auto &aStride : enumerate(aStrides))
670  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
671  return false;
672  }
673  if (aT.getMemorySpace() != bT.getMemorySpace())
674  return false;
675 
676  // They must have the same rank, and any specified dimensions must match.
677  if (aT.getRank() != bT.getRank())
678  return false;
679 
680  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
681  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
682  if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
683  aDim != bDim)
684  return false;
685  }
686  return true;
687  } else {
688  if (!aT && !uaT)
689  return false;
690  if (!bT && !ubT)
691  return false;
692  // Unranked to unranked casting is unsupported
693  if (uaT && ubT)
694  return false;
695 
696  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
697  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
698  if (aEltType != bEltType)
699  return false;
700 
701  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
702  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
703  return aMemSpace == bMemSpace;
704  }
705 
706  return false;
707 }
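// For illustration, a minimal sketch: casting a ranked memref to its unranked
// counterpart is compatible,
//
// ```mlir
// %u = memref.cast %m : memref<4xf32> to memref<*xf32>
// ```
//
// whereas a cast that changes the element type (e.g. to memref<4xi32>) is
// rejected.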
708 
709 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
710  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // CopyOp
715 //===----------------------------------------------------------------------===//
716 
717 namespace {
718 
719 /// Fold memref.copy(%x, %x).
720 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
721  using OpRewritePattern<CopyOp>::OpRewritePattern;
722 
723  LogicalResult matchAndRewrite(CopyOp copyOp,
724  PatternRewriter &rewriter) const override {
725  if (copyOp.getSource() != copyOp.getTarget())
726  return failure();
727 
728  rewriter.eraseOp(copyOp);
729  return success();
730  }
731 };
732 
733 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
734  using OpRewritePattern<CopyOp>::OpRewritePattern;
735 
736  static bool isEmptyMemRef(BaseMemRefType type) {
737  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
738  }
739 
740  LogicalResult matchAndRewrite(CopyOp copyOp,
741  PatternRewriter &rewriter) const override {
742  if (isEmptyMemRef(copyOp.getSource().getType()) ||
743  isEmptyMemRef(copyOp.getTarget().getType())) {
744  rewriter.eraseOp(copyOp);
745  return success();
746  }
747 
748  return failure();
749  }
750 };
751 } // namespace
752 
753 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
754  MLIRContext *context) {
755  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
756 }
757 
758 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
759 /// and element type, the cast can be skipped. Such CastOps only cast the layout
760 /// of the type.
761 static LogicalResult FoldCopyOfCast(CopyOp op) {
762  for (OpOperand &operand : op->getOpOperands()) {
763  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
764  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
765  operand.set(castOp.getOperand());
766  return success();
767  }
768  }
769  return failure();
770 }
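// For illustration, a minimal sketch: a cast toward a more dynamic type that
// feeds a copy can be bypassed, so
//
// ```mlir
// %c = memref.cast %src : memref<8xf32> to memref<?xf32>
// memref.copy %c, %dst : memref<?xf32> to memref<?xf32>
// ```
//
// may fold to
//
// ```mlir
// memref.copy %src, %dst : memref<8xf32> to memref<?xf32>
// ```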
771 
772 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
773  SmallVectorImpl<OpFoldResult> &results) {
774 
775  /// copy(memrefcast) -> copy
776  return FoldCopyOfCast(*this);
777 }
778 
779 //===----------------------------------------------------------------------===//
780 // DeallocOp
781 //===----------------------------------------------------------------------===//
782 
783 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
784  SmallVectorImpl<OpFoldResult> &results) {
785  /// dealloc(memrefcast) -> dealloc
786  return foldMemRefCast(*this);
787 }
788 
789 //===----------------------------------------------------------------------===//
790 // DimOp
791 //===----------------------------------------------------------------------===//
792 
793 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
794  setNameFn(getResult(), "dim");
795 }
796 
797 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
798  int64_t index) {
799  auto loc = result.location;
800  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
801  build(builder, result, source, indexValue);
802 }
803 
804 std::optional<int64_t> DimOp::getConstantIndex() {
805  return getConstantIntValue(getIndex());
806 }
807 
808 Speculation::Speculatability DimOp::getSpeculatability() {
809  auto constantIndex = getConstantIndex();
810  if (!constantIndex)
811  return Speculation::NotSpeculatable;
812 
813  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
814  if (!rankedSourceType)
815  return Speculation::NotSpeculatable;
816 
817  if (rankedSourceType.getRank() <= constantIndex)
818  return Speculation::NotSpeculatable;
819 
820  return Speculation::Speculatable;
821 }
822 
823 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
824  SetIntLatticeFn setResultRange) {
825  setResultRange(getResult(),
826  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
827 }
828 
829 /// Return a map from each element in `vals` to its number of
830 /// occurrences. Use std::map, since the `vals` here are strides and the
831 /// dynamic stride value is the same as the tombstone value for
832 /// `DenseMap<int64_t>`.
833 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
834  std::map<int64_t, unsigned> numOccurences;
835  for (auto val : vals)
836  numOccurences[val]++;
837  return numOccurences;
838 }
839 
840 /// Given the `originalType` and a `reducedType` whose shape is assumed
841 /// to be a subset of `originalType` with some `1` entries erased, return the
842 /// set of indices that specifies which of the entries of the original shape
843 /// are dropped to obtain the reduced shape.
844 /// This accounts for cases where there are multiple unit-dims, but only a
845 /// subset of those are dropped. For MemRefTypes these can be disambiguated
846 /// using the strides. If a dimension is dropped the stride must be dropped too.
847 static FailureOr<llvm::SmallBitVector>
848 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
849  ArrayRef<OpFoldResult> sizes) {
850  llvm::SmallBitVector unusedDims(originalType.getRank());
851  if (originalType.getRank() == reducedType.getRank())
852  return unusedDims;
853 
854  for (const auto &dim : llvm::enumerate(sizes))
855  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
856  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
857  unusedDims.set(dim.index());
858 
859  // Early exit for the case where the number of unused dims matches the number
860  // of ranks reduced.
861  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
862  originalType.getRank())
863  return unusedDims;
864 
865  SmallVector<int64_t> originalStrides, candidateStrides;
866  int64_t originalOffset, candidateOffset;
867  if (failed(
868  originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
869  failed(
870  reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
871  return failure();
872 
873  // For memrefs, a dimension is truly dropped if its corresponding stride is
874  // also dropped. This is particularly important when more than one of the dims
875  // is 1. Track the number of occurrences of the strides in the original type
876  // and the candidate type. For each unused dim, that stride should not be
877  // present in the candidate type. Note that there could be multiple dimensions
878  // that have the same size. We don't need to figure out exactly which dim
879  // corresponds to which stride; we just need to verify that the number of
880  // repetitions of a stride in the original + the number of unused dims with
881  // that stride == the number of repetitions of that stride in the candidate.
882  std::map<int64_t, unsigned> currUnaccountedStrides =
883  getNumOccurences(originalStrides);
884  std::map<int64_t, unsigned> candidateStridesNumOccurences =
885  getNumOccurences(candidateStrides);
886  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
887  if (!unusedDims.test(dim))
888  continue;
889  int64_t originalStride = originalStrides[dim];
890  if (currUnaccountedStrides[originalStride] >
891  candidateStridesNumOccurences[originalStride]) {
892  // This dim can be treated as dropped.
893  currUnaccountedStrides[originalStride]--;
894  continue;
895  }
896  if (currUnaccountedStrides[originalStride] ==
897  candidateStridesNumOccurences[originalStride]) {
898  // The stride for this is not dropped. Keep as is.
899  unusedDims.reset(dim);
900  continue;
901  }
902  if (currUnaccountedStrides[originalStride] <
903  candidateStridesNumOccurences[originalStride]) {
904  // This should never happen. We can't have a stride in the reduced-rank
905  // type that wasn't in the original one.
906  return failure();
907  }
908  }
909 
910  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
911  originalType.getRank())
912  return failure();
913  return unusedDims;
914 }
915 
916 llvm::SmallBitVector SubViewOp::getDroppedDims() {
917  MemRefType sourceType = getSourceType();
918  MemRefType resultType = getType();
919  FailureOr<llvm::SmallBitVector> unusedDims =
920  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
921  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
922  return *unusedDims;
923 }
924 
925 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
926  // All forms of folding require a known index.
927  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
928  if (!index)
929  return {};
930 
931  // Folding for unranked types (UnrankedMemRefType) is not supported.
932  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
933  if (!memrefType)
934  return {};
935 
936  // Out of bound indices produce undefined behavior but are still valid IR.
937  // Don't choke on them.
938  int64_t indexVal = index.getInt();
939  if (indexVal < 0 || indexVal >= memrefType.getRank())
940  return {};
941 
942  // Fold if the shape extent along the given index is known.
943  if (!memrefType.isDynamicDim(index.getInt())) {
944  Builder builder(getContext());
945  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
946  }
947 
948  // The size at the given index is now known to be a dynamic size.
949  unsigned unsignedIndex = index.getValue().getZExtValue();
950 
951  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
952  Operation *definingOp = getSource().getDefiningOp();
953 
954  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
955  return *(alloc.getDynamicSizes().begin() +
956  memrefType.getDynamicDimIndex(unsignedIndex));
957 
958  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
959  return *(alloca.getDynamicSizes().begin() +
960  memrefType.getDynamicDimIndex(unsignedIndex));
961 
962  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
963  return *(view.getDynamicSizes().begin() +
964  memrefType.getDynamicDimIndex(unsignedIndex));
965 
966  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
967  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
968  unsigned resultIndex = 0;
969  unsigned sourceRank = subview.getSourceType().getRank();
970  unsigned sourceIndex = 0;
971  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
972  if (unusedDims.test(i))
973  continue;
974  if (resultIndex == unsignedIndex) {
975  sourceIndex = i;
976  break;
977  }
978  resultIndex++;
979  }
980  assert(subview.isDynamicSize(sourceIndex) &&
981  "expected dynamic subview size");
982  return subview.getDynamicSize(sourceIndex);
983  }
984 
985  if (auto sizeInterface =
986  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
987  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
988  "Expected dynamic subview size");
989  return sizeInterface.getDynamicSize(unsignedIndex);
990  }
991 
992  // dim(memrefcast) -> dim
993  if (succeeded(foldMemRefCast(*this)))
994  return getResult();
995 
996  return {};
997 }
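// For illustration, a minimal sketch of the alloc case: %d below folds to %n.
//
// ```mlir
// %a = memref.alloc(%n) : memref<?xf32>
// %c0 = arith.constant 0 : index
// %d = memref.dim %a, %c0 : memref<?xf32>
// ```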
998 
999 namespace {
1000 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1001 /// operand.
1002 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1004 
1005  LogicalResult matchAndRewrite(DimOp dim,
1006  PatternRewriter &rewriter) const override {
1007  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1008 
1009  if (!reshape)
1010  return rewriter.notifyMatchFailure(
1011  dim, "Dim op is not defined by a reshape op.");
1012 
1013  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1014  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1015  // cheaply check that either of the following conditions hold:
1016  // 1. dim.getIndex() is defined in the same block as reshape but before
1017  // reshape.
1018  // 2. dim.getIndex() is defined in a parent block of
1019  // reshape.
1020 
1021  // Check condition 1
1022  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1023  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1024  if (reshape->isBeforeInBlock(definingOp)) {
1025  return rewriter.notifyMatchFailure(
1026  dim,
1027  "dim.getIndex is not defined before reshape in the same block.");
1028  }
1029  } // else dim.getIndex is a block argument to reshape->getBlock and
1030  // dominates reshape
1031  } // Check condition 2
1032  else if (dim->getBlock() != reshape->getBlock() &&
1033  !dim.getIndex().getParentRegion()->isProperAncestor(
1034  reshape->getParentRegion())) {
1035  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1036  // already know dim.getIndex() dominates reshape without calling
1037  // `isProperAncestor`
1038  return rewriter.notifyMatchFailure(
1039  dim, "dim.getIndex does not dominate reshape.");
1040  }
1041 
1042  // Place the load directly after the reshape to ensure that the shape memref
1043  // was not mutated.
1044  rewriter.setInsertionPointAfter(reshape);
1045  Location loc = dim.getLoc();
1046  Value load =
1047  LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1048  if (load.getType() != dim.getType())
1049  load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
1050  rewriter.replaceOp(dim, load);
1051  return success();
1052  }
1053 };
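// For illustration, a minimal sketch (the assembly below is schematic):
//
// ```mlir
// %r = memref.reshape %m(%shape)
//        : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
// %d = memref.dim %r, %c1 : memref<?x?xf32>
// ```
//
// may rewrite the dim to a load of the shape operand:
//
// ```mlir
// %d = memref.load %shape[%c1] : memref<2xindex>
// ```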
1054 
1055 } // namespace
1056 
1057 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1058  MLIRContext *context) {
1059  results.add<DimOfMemRefReshape>(context);
1060 }
1061 
1062 // ---------------------------------------------------------------------------
1063 // DmaStartOp
1064 // ---------------------------------------------------------------------------
1065 
1066 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1067  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1068  ValueRange destIndices, Value numElements,
1069  Value tagMemRef, ValueRange tagIndices, Value stride,
1070  Value elementsPerStride) {
1071  result.addOperands(srcMemRef);
1072  result.addOperands(srcIndices);
1073  result.addOperands(destMemRef);
1074  result.addOperands(destIndices);
1075  result.addOperands({numElements, tagMemRef});
1076  result.addOperands(tagIndices);
1077  if (stride)
1078  result.addOperands({stride, elementsPerStride});
1079 }
1080 
1081 void DmaStartOp::print(OpAsmPrinter &p) {
1082  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1083  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1084  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1085  if (isStrided())
1086  p << ", " << getStride() << ", " << getNumElementsPerStride();
1087 
1088  p.printOptionalAttrDict((*this)->getAttrs());
1089  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1090  << ", " << getTagMemRef().getType();
1091 }
1092 
1093 // Parse DmaStartOp.
1094 // Ex:
1095 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1096 // %tag[%index], %stride, %num_elt_per_stride :
1097 // : memref<3076 x f32, 0>,
1098 // memref<1024 x f32, 2>,
1099 // memref<1 x i32>
1100 //
1101 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1102  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1103  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1104  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1105  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1106  OpAsmParser::UnresolvedOperand numElementsInfo;
1107  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1108  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1109  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1110 
1111  SmallVector<Type, 3> types;
1112  auto indexType = parser.getBuilder().getIndexType();
1113 
1114  // Parse and resolve the following list of operands:
1115  // *) source memref followed by its indices (in square brackets).
1116  // *) destination memref followed by its indices (in square brackets).
1117  // *) dma size (number of elements).
1118  if (parser.parseOperand(srcMemRefInfo) ||
1119  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1120  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1121  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1122  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1123  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1124  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1125  return failure();
1126 
1127  // Parse optional stride and elements per stride.
1128  if (parser.parseTrailingOperandList(strideInfo))
1129  return failure();
1130 
1131  bool isStrided = strideInfo.size() == 2;
1132  if (!strideInfo.empty() && !isStrided) {
1133  return parser.emitError(parser.getNameLoc(),
1134  "expected two stride related operands");
1135  }
1136 
1137  if (parser.parseColonTypeList(types))
1138  return failure();
1139  if (types.size() != 3)
1140  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1141 
1142  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1143  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1144  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1145  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1146  // size should be an index.
1147  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1148  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1149  // tag indices should be index.
1150  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1151  return failure();
1152 
1153  if (isStrided) {
1154  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1155  return failure();
1156  }
1157 
1158  return success();
1159 }
1160 
1161 LogicalResult DmaStartOp::verify() {
1162  unsigned numOperands = getNumOperands();
1163 
1164  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1165  // the number of elements.
1166  if (numOperands < 4)
1167  return emitOpError("expected at least 4 operands");
1168 
1169  // Check types of operands. The order of these calls is important: the later
1170  // calls rely on some type properties to compute the operand position.
1171  // 1. Source memref.
1172  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1173  return emitOpError("expected source to be of memref type");
1174  if (numOperands < getSrcMemRefRank() + 4)
1175  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1176  << " operands";
1177  if (!getSrcIndices().empty() &&
1178  !llvm::all_of(getSrcIndices().getTypes(),
1179  [](Type t) { return t.isIndex(); }))
1180  return emitOpError("expected source indices to be of index type");
1181 
1182  // 2. Destination memref.
1183  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1184  return emitOpError("expected destination to be of memref type");
1185  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1186  if (numOperands < numExpectedOperands)
1187  return emitOpError() << "expected at least " << numExpectedOperands
1188  << " operands";
1189  if (!getDstIndices().empty() &&
1190  !llvm::all_of(getDstIndices().getTypes(),
1191  [](Type t) { return t.isIndex(); }))
1192  return emitOpError("expected destination indices to be of index type");
1193 
1194  // 3. Number of elements.
1195  if (!getNumElements().getType().isIndex())
1196  return emitOpError("expected num elements to be of index type");
1197 
1198  // 4. Tag memref.
1199  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1200  return emitOpError("expected tag to be of memref type");
1201  numExpectedOperands += getTagMemRefRank();
1202  if (numOperands < numExpectedOperands)
1203  return emitOpError() << "expected at least " << numExpectedOperands
1204  << " operands";
1205  if (!getTagIndices().empty() &&
1206  !llvm::all_of(getTagIndices().getTypes(),
1207  [](Type t) { return t.isIndex(); }))
1208  return emitOpError("expected tag indices to be of index type");
1209 
1210  // Optional stride-related operands must be either both present or both
1211  // absent.
1212  if (numOperands != numExpectedOperands &&
1213  numOperands != numExpectedOperands + 2)
1214  return emitOpError("incorrect number of operands");
1215 
1216  // 5. Strides.
1217  if (isStrided()) {
1218  if (!getStride().getType().isIndex() ||
1219  !getNumElementsPerStride().getType().isIndex())
1220  return emitOpError(
1221  "expected stride and num elements per stride to be of type index");
1222  }
1223 
1224  return success();
1225 }
1226 
1227 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1228  SmallVectorImpl<OpFoldResult> &results) {
1229  /// dma_start(memrefcast) -> dma_start
1230  return foldMemRefCast(*this);
1231 }
1232 
1233 // ---------------------------------------------------------------------------
1234 // DmaWaitOp
1235 // ---------------------------------------------------------------------------
1236 
1237 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1238  SmallVectorImpl<OpFoldResult> &results) {
1239  /// dma_wait(memrefcast) -> dma_wait
1240  return foldMemRefCast(*this);
1241 }
1242 
1243 LogicalResult DmaWaitOp::verify() {
1244  // Check that the number of tag indices matches the tagMemRef rank.
1245  unsigned numTagIndices = getTagIndices().size();
1246  unsigned tagMemRefRank = getTagMemRefRank();
1247  if (numTagIndices != tagMemRefRank)
1248  return emitOpError() << "expected tagIndices to have the same number of "
1249  "elements as the tagMemRef rank, expected "
1250  << tagMemRefRank << ", but got " << numTagIndices;
1251  return success();
1252 }
1253 
1254 //===----------------------------------------------------------------------===//
1255 // ExtractAlignedPointerAsIndexOp
1256 //===----------------------------------------------------------------------===//
1257 
1258 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1259  function_ref<void(Value, StringRef)> setNameFn) {
1260  setNameFn(getResult(), "intptr");
1261 }
1262 
1263 //===----------------------------------------------------------------------===//
1264 // ExtractStridedMetadataOp
1265 //===----------------------------------------------------------------------===//
1266 
1267 /// The number and type of the results are inferred from the
1268 /// shape of the source.
1269 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1270  MLIRContext *context, std::optional<Location> location,
1271  ExtractStridedMetadataOp::Adaptor adaptor,
1272  SmallVectorImpl<Type> &inferredReturnTypes) {
1273  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1274  if (!sourceType)
1275  return failure();
1276 
1277  unsigned sourceRank = sourceType.getRank();
1278  IndexType indexType = IndexType::get(context);
1279  auto memrefType =
1280  MemRefType::get({}, sourceType.getElementType(),
1281  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1282  // Base.
1283  inferredReturnTypes.push_back(memrefType);
1284  // Offset.
1285  inferredReturnTypes.push_back(indexType);
1286  // Sizes and strides.
1287  for (unsigned i = 0; i < sourceRank * 2; ++i)
1288  inferredReturnTypes.push_back(indexType);
1289  return success();
1290 }
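// For illustration, a sketch of the inferred results (the assembly below is
// schematic): a 0-d base memref, an offset, and one size plus one stride per
// source dimension.
//
// ```mlir
// %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
//     : memref<4x?xf32> -> memref<f32>, index, index, index, index, index
// ```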
1291 
1292 void ExtractStridedMetadataOp::getAsmResultNames(
1293  function_ref<void(Value, StringRef)> setNameFn) {
1294  setNameFn(getBaseBuffer(), "base_buffer");
1295  setNameFn(getOffset(), "offset");
1296  // For multi-result to work properly with pretty names and packed syntax `x:3`
1297  // we can only give a pretty name to the first value in the pack.
1298  if (!getSizes().empty()) {
1299  setNameFn(getSizes().front(), "sizes");
1300  setNameFn(getStrides().front(), "strides");
1301  }
1302 }
1303 
1304 /// Helper function to perform the replacement of all constant uses of `values`
1305 /// by a materialized constant extracted from `maybeConstants`.
1306 /// `values` and `maybeConstants` are expected to have the same size.
1307 template <typename Container>
1308 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1309  Container values,
1310  ArrayRef<OpFoldResult> maybeConstants) {
1311  assert(values.size() == maybeConstants.size() &&
1312  " expected values and maybeConstants of the same size");
1313  bool atLeastOneReplacement = false;
1314  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1315  // Don't materialize a constant if there are no uses: this would induce
1316  // infinite loops in the driver.
1317  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1318  continue;
1319  assert(isa<Attribute>(maybeConstant) &&
1320  "The constified value should be either unchanged (i.e., == result) "
1321  "or a constant");
1322  Value constantVal = arith::ConstantIndexOp::create(
1323  rewriter, loc,
1324  llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1325  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1326  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1327  // yet.
1328  op->replaceUsesOfWith(result, constantVal);
1329  atLeastOneReplacement = true;
1330  }
1331  }
1332  return atLeastOneReplacement;
1333 }
1334 
1335 LogicalResult
1336 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1337  SmallVectorImpl<OpFoldResult> &results) {
1338  OpBuilder builder(*this);
1339 
1340  bool atLeastOneReplacement = replaceConstantUsesOf(
1341  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1342  getConstifiedMixedOffset());
1343  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1344  getConstifiedMixedSizes());
1345  atLeastOneReplacement |= replaceConstantUsesOf(
1346  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1347 
1348  return success(atLeastOneReplacement);
1349 }
1350 
1351 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1352  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1353  constifyIndexValues(values, getSource().getType().getShape());
1354  return values;
1355 }
1356 
1357 SmallVector<OpFoldResult>
1358 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1359  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1360  SmallVector<int64_t> staticValues;
1361  int64_t unused;
1362  LogicalResult status =
1363  getSource().getType().getStridesAndOffset(staticValues, unused);
1364  (void)status;
1365  assert(succeeded(status) && "could not get strides from type");
1366  constifyIndexValues(values, staticValues);
1367  return values;
1368 }
1369 
1370 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1371  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1372  SmallVector<OpFoldResult> values(1, offsetOfr);
1373  SmallVector<int64_t> staticValues, unused;
1374  int64_t offset;
1375  LogicalResult status =
1376  getSource().getType().getStridesAndOffset(unused, offset);
1377  (void)status;
1378  assert(succeeded(status) && "could not get offset from type");
1379  staticValues.push_back(offset);
1380  constifyIndexValues(values, staticValues);
1381  return values[0];
1382 }
1383 
1384 //===----------------------------------------------------------------------===//
1385 // GenericAtomicRMWOp
1386 //===----------------------------------------------------------------------===//
1387 
1388 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1389  Value memref, ValueRange ivs) {
1390  OpBuilder::InsertionGuard g(builder);
1391  result.addOperands(memref);
1392  result.addOperands(ivs);
1393 
1394  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1395  Type elementType = memrefType.getElementType();
1396  result.addTypes(elementType);
1397 
1398  Region *bodyRegion = result.addRegion();
1399  builder.createBlock(bodyRegion);
1400  bodyRegion->addArgument(elementType, memref.getLoc());
1401  }
1402 }
1403 
1404 LogicalResult GenericAtomicRMWOp::verify() {
1405  auto &body = getRegion();
1406  if (body.getNumArguments() != 1)
1407  return emitOpError("expected single number of entry block arguments");
1408 
1409  if (getResult().getType() != body.getArgument(0).getType())
1410  return emitOpError("expected block argument of the same type result type");
1411 
1412  bool hasSideEffects =
1413  body.walk([&](Operation *nestedOp) {
1414  if (isMemoryEffectFree(nestedOp))
1415  return WalkResult::advance();
1416  nestedOp->emitError(
1417  "body of 'memref.generic_atomic_rmw' should contain "
1418  "only operations with no side effects");
1419  return WalkResult::interrupt();
1420  })
1421  .wasInterrupted();
1422  return hasSideEffects ? failure() : success();
1423 }
1424 
1425 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1426  OperationState &result) {
1427  OpAsmParser::UnresolvedOperand memref;
1428  Type memrefType;
1429  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1430 
1431  Type indexType = parser.getBuilder().getIndexType();
1432  if (parser.parseOperand(memref) ||
1433  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1434  parser.parseColonType(memrefType) ||
1435  parser.resolveOperand(memref, memrefType, result.operands) ||
1436  parser.resolveOperands(ivs, indexType, result.operands))
1437  return failure();
1438 
1439  Region *body = result.addRegion();
1440  if (parser.parseRegion(*body, {}) ||
1441  parser.parseOptionalAttrDict(result.attributes))
1442  return failure();
1443  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1444  return success();
1445 }
1446 
1447 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1448  p << ' ' << getMemref() << "[" << getIndices()
1449  << "] : " << getMemref().getType() << ' ';
1450  p.printRegion(getRegion());
1451  p.printOptionalAttrDict((*this)->getAttrs());
1452 }
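// For illustration, a sketch of the syntax handled by the parser and printer
// above (the assembly below is schematic):
//
// ```mlir
// %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
// ^bb0(%current_value : f32):
//   %c1 = arith.constant 1.0 : f32
//   %inc = arith.addf %c1, %current_value : f32
//   memref.atomic_yield %inc : f32
// }
// ```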
1453 
1454 //===----------------------------------------------------------------------===//
1455 // AtomicYieldOp
1456 //===----------------------------------------------------------------------===//
1457 
1458 LogicalResult AtomicYieldOp::verify() {
1459  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1460  Type resultType = getResult().getType();
1461  if (parentType != resultType)
1462  return emitOpError() << "types mismatch between yield op: " << resultType
1463  << " and its parent: " << parentType;
1464  return success();
1465 }
1466 
1467 //===----------------------------------------------------------------------===//
1468 // GlobalOp
1469 //===----------------------------------------------------------------------===//
1470 
1471 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1472  TypeAttr type,
1473  Attribute initialValue) {
1474  p << type;
1475  if (!op.isExternal()) {
1476  p << " = ";
1477  if (op.isUninitialized())
1478  p << "uninitialized";
1479  else
1480  p.printAttributeWithoutType(initialValue);
1481  }
1482 }
1483 
1484 static ParseResult
1485 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1486  Attribute &initialValue) {
1487  Type type;
1488  if (parser.parseType(type))
1489  return failure();
1490 
1491  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1492  if (!memrefType || !memrefType.hasStaticShape())
1493  return parser.emitError(parser.getNameLoc())
1494  << "type should be static shaped memref, but got " << type;
1495  typeAttr = TypeAttr::get(type);
1496 
1497  if (parser.parseOptionalEqual())
1498  return success();
1499 
1500  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1501  initialValue = UnitAttr::get(parser.getContext());
1502  return success();
1503  }
1504 
1505  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1506  if (parser.parseAttribute(initialValue, tensorType))
1507  return failure();
1508  if (!llvm::isa<ElementsAttr>(initialValue))
1509  return parser.emitError(parser.getNameLoc())
1510  << "initial value should be a unit or elements attribute";
1511  return success();
1512 }
1513 
1514 LogicalResult GlobalOp::verify() {
1515  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1516  if (!memrefType || !memrefType.hasStaticShape())
1517  return emitOpError("type should be static shaped memref, but got ")
1518  << getType();
1519 
1520  // Verify that the initial value, if present, is either a unit attribute or
1521  // an elements attribute.
1522  if (getInitialValue().has_value()) {
1523  Attribute initValue = getInitialValue().value();
1524  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1525  return emitOpError("initial value should be a unit or elements "
1526  "attribute, but got ")
1527  << initValue;
1528 
1529  // Check that the type of the initial value is compatible with the type of
1530  // the global variable.
1531  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1532  // Check the element types match.
1533  auto initElementType =
1534  cast<TensorType>(elementsAttr.getType()).getElementType();
1535  auto memrefElementType = memrefType.getElementType();
1536 
1537  if (initElementType != memrefElementType)
1538  return emitOpError("initial value element expected to be of type ")
1539  << memrefElementType << ", but was of type " << initElementType;
1540 
1541  // Check that the shapes match. Since memref globals can only produce
1542  // statically shaped memrefs and an elements literal type must have a
1543  // static shape, we can assume both types are shaped.
1544  auto initShape = elementsAttr.getShapedType().getShape();
1545  auto memrefShape = memrefType.getShape();
1546  if (initShape != memrefShape)
1547  return emitOpError("initial value shape expected to be ")
1548  << memrefShape << " but was " << initShape;
1549  }
1550  }
1551 
1552  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1553  uint64_t alignment = *alignAttr;
1554 
1555  if (!llvm::isPowerOf2_64(alignment))
1556  return emitError() << "alignment attribute value " << alignment
1557  << " is not a power of 2";
1558  }
1559 
1560  // TODO: verify visibility for declarations.
1561  return success();
1562 }
1563 
1564 ElementsAttr GlobalOp::getConstantInitValue() {
1565  auto initVal = getInitialValue();
1566  if (getConstant() && initVal.has_value())
1567  return llvm::cast<ElementsAttr>(initVal.value());
1568  return {};
1569 }
1570 
1571 //===----------------------------------------------------------------------===//
1572 // GetGlobalOp
1573 //===----------------------------------------------------------------------===//
1574 
1575 LogicalResult
1576 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1577  // Verify that the result type is the same as the type of the referenced
1578  // memref.global op.
1579  auto global =
1580  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1581  if (!global)
1582  return emitOpError("'")
1583  << getName() << "' does not reference a valid global memref";
1584 
1585  Type resultType = getResult().getType();
1586  if (global.getType() != resultType)
1587  return emitOpError("result type ")
1588  << resultType << " does not match type " << global.getType()
1589  << " of the global memref @" << getName();
1590  return success();
1591 }
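// For example, the following pair verifies because the result type matches the
// referenced global exactly (the symbol name is an illustrative placeholder):
//
//   memref.global "private" constant @c : memref<2xi32> = dense<[1, 4]>
//   ...
//   %0 = memref.get_global @c : memref<2xi32>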
1592 
1593 //===----------------------------------------------------------------------===//
1594 // LoadOp
1595 //===----------------------------------------------------------------------===//
1596 
1597 LogicalResult LoadOp::verify() {
1598  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1599  return emitOpError("incorrect number of indices for load, expected ")
1600  << getMemRefType().getRank() << " but got " << getIndices().size();
1601  }
1602  return success();
1603 }
1604 
1605 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1606  /// load(memrefcast) -> load
1607  if (succeeded(foldMemRefCast(*this)))
1608  return getResult();
1609  return OpFoldResult();
1610 }
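// As a sketch, the fold above rewrites a load through a ranked cast so that it
// reads from the original memref directly (names are placeholders):
//
//   %0 = memref.cast %A : memref<4xf32> to memref<?xf32>
//   %1 = memref.load %0[%i] : memref<?xf32>
// becomes
//   %1 = memref.load %A[%i] : memref<4xf32>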
1611 
1612 //===----------------------------------------------------------------------===//
1613 // MemorySpaceCastOp
1614 //===----------------------------------------------------------------------===//
1615 
1616 void MemorySpaceCastOp::getAsmResultNames(
1617  function_ref<void(Value, StringRef)> setNameFn) {
1618  setNameFn(getResult(), "memspacecast");
1619 }
1620 
1621 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1622  if (inputs.size() != 1 || outputs.size() != 1)
1623  return false;
1624  Type a = inputs.front(), b = outputs.front();
1625  auto aT = llvm::dyn_cast<MemRefType>(a);
1626  auto bT = llvm::dyn_cast<MemRefType>(b);
1627 
1628  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1629  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1630 
1631  if (aT && bT) {
1632  if (aT.getElementType() != bT.getElementType())
1633  return false;
1634  if (aT.getLayout() != bT.getLayout())
1635  return false;
1636  if (aT.getShape() != bT.getShape())
1637  return false;
1638  return true;
1639  }
1640  if (uaT && ubT) {
1641  return uaT.getElementType() == ubT.getElementType();
1642  }
1643  return false;
1644 }
1645 
1646 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1647  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1648  // t2)
1649  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1650  getSourceMutable().assign(parentCast.getSource());
1651  return getResult();
1652  }
1653  return Value{};
1654 }
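// As a sketch, the fold above collapses a chain of casts into a single cast
// from the original source (the memory space numbers are placeholders):
//
//   %a = memref.memory_space_cast %m : memref<4xf32> to memref<4xf32, 1>
//   %b = memref.memory_space_cast %a : memref<4xf32, 1> to memref<4xf32, 2>
// becomes
//   %b = memref.memory_space_cast %m : memref<4xf32> to memref<4xf32, 2>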
1655 
1656 //===----------------------------------------------------------------------===//
1657 // PrefetchOp
1658 //===----------------------------------------------------------------------===//
1659 
1660 void PrefetchOp::print(OpAsmPrinter &p) {
1661  p << " " << getMemref() << '[';
1662  p.printOperands(getIndices());
1663  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1664  p << ", locality<" << getLocalityHint();
1665  p << ">, " << (getIsDataCache() ? "data" : "instr");
1666  p.printOptionalAttrDict(
1667  (*this)->getAttrs(),
1668  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1669  p << " : " << getMemRefType();
1670 }
1671 
1672 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1673  OpAsmParser::UnresolvedOperand memrefInfo;
1674  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1675  IntegerAttr localityHint;
1676  MemRefType type;
1677  StringRef readOrWrite, cacheType;
1678 
1679  auto indexTy = parser.getBuilder().getIndexType();
1680  auto i32Type = parser.getBuilder().getIntegerType(32);
1681  if (parser.parseOperand(memrefInfo) ||
1682  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1683  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1684  parser.parseComma() || parser.parseKeyword("locality") ||
1685  parser.parseLess() ||
1686  parser.parseAttribute(localityHint, i32Type, "localityHint",
1687  result.attributes) ||
1688  parser.parseGreater() || parser.parseComma() ||
1689  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1690  parser.resolveOperand(memrefInfo, type, result.operands) ||
1691  parser.resolveOperands(indexInfo, indexTy, result.operands))
1692  return failure();
1693 
1694  if (readOrWrite != "read" && readOrWrite != "write")
1695  return parser.emitError(parser.getNameLoc(),
1696  "rw specifier has to be 'read' or 'write'");
1697  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1698  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1699 
1700  if (cacheType != "data" && cacheType != "instr")
1701  return parser.emitError(parser.getNameLoc(),
1702  "cache type has to be 'data' or 'instr'");
1703 
1704  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1705  parser.getBuilder().getBoolAttr(cacheType == "data"));
1706 
1707  return success();
1708 }
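// For reference, the custom form handled by the print/parse methods above is,
// e.g. (operands and the memref type are illustrative placeholders):
//
//   memref.prefetch %A[%i, %j], read, locality<3>, data : memref<400x400xi32>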
1709 
1710 LogicalResult PrefetchOp::verify() {
1711  if (getNumOperands() != 1 + getMemRefType().getRank())
1712  return emitOpError("too few indices");
1713 
1714  return success();
1715 }
1716 
1717 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1718  SmallVectorImpl<OpFoldResult> &results) {
1719  // prefetch(memrefcast) -> prefetch
1720  return foldMemRefCast(*this);
1721 }
1722 
1723 //===----------------------------------------------------------------------===//
1724 // RankOp
1725 //===----------------------------------------------------------------------===//
1726 
1727 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1728  // Constant fold rank when the rank of the operand is known.
1729  auto type = getOperand().getType();
1730  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1731  if (shapedType && shapedType.hasRank())
1732  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1733  return IntegerAttr();
1734 }
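// As a sketch, folding turns a rank query on a ranked operand into the index
// constant 2, which is typically materialized as a constant op (the memref
// type below is an illustrative placeholder):
//
//   %r = memref.rank %m : memref<4x?xf32>
// becomes
//   %r = arith.constant 2 : index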
1735 
1736 //===----------------------------------------------------------------------===//
1737 // ReinterpretCastOp
1738 //===----------------------------------------------------------------------===//
1739 
1740 void ReinterpretCastOp::getAsmResultNames(
1741  function_ref<void(Value, StringRef)> setNameFn) {
1742  setNameFn(getResult(), "reinterpret_cast");
1743 }
1744 
1745 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1746 /// `staticSizes` and `staticStrides` are automatically filled with
1747 /// source-memref-rank sentinel values that encode dynamic entries.
1748 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1749  MemRefType resultType, Value source,
1750  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1751  ArrayRef<OpFoldResult> strides,
1752  ArrayRef<NamedAttribute> attrs) {
1753  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1754  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1755  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1756  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1757  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1758  result.addAttributes(attrs);
1759  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1760  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1761  b.getDenseI64ArrayAttr(staticSizes),
1762  b.getDenseI64ArrayAttr(staticStrides));
1763 }
1764 
1765 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1766  Value source, OpFoldResult offset,
1767  ArrayRef<OpFoldResult> sizes,
1768  ArrayRef<OpFoldResult> strides,
1769  ArrayRef<NamedAttribute> attrs) {
1770  auto sourceType = cast<BaseMemRefType>(source.getType());
1771  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1772  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1773  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1774  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1775  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1776  auto stridedLayout = StridedLayoutAttr::get(
1777  b.getContext(), staticOffsets.front(), staticStrides);
1778  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1779  stridedLayout, sourceType.getMemorySpace());
1780  build(b, result, resultType, source, offset, sizes, strides, attrs);
1781 }
1782 
1783 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1784  MemRefType resultType, Value source,
1785  int64_t offset, ArrayRef<int64_t> sizes,
1786  ArrayRef<int64_t> strides,
1787  ArrayRef<NamedAttribute> attrs) {
1788  SmallVector<OpFoldResult> sizeValues =
1789  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1790  return b.getI64IntegerAttr(v);
1791  }));
1792  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1793  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1794  return b.getI64IntegerAttr(v);
1795  }));
1796  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1797  strideValues, attrs);
1798 }
1799 
1800 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1801  MemRefType resultType, Value source, Value offset,
1802  ValueRange sizes, ValueRange strides,
1803  ArrayRef<NamedAttribute> attrs) {
1804  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1805  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1806  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1807  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1808  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1809 }
1810 
1811 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1812 // completed automatically, like we have for subview and extract_slice.
1813 LogicalResult ReinterpretCastOp::verify() {
1814  // The source and result memrefs should be in the same memory space.
1815  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1816  auto resultType = llvm::cast<MemRefType>(getType());
1817  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1818  return emitError("different memory spaces specified for source type ")
1819  << srcType << " and result memref type " << resultType;
1820  if (srcType.getElementType() != resultType.getElementType())
1821  return emitError("different element types specified for source type ")
1822  << srcType << " and result memref type " << resultType;
1823 
1824  // Match sizes in result memref type and in static_sizes attribute.
1825  for (auto [idx, resultSize, expectedSize] :
1826  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1827  if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
1828  return emitError("expected result type with size = ")
1829  << (ShapedType::isDynamic(expectedSize)
1830  ? std::string("dynamic")
1831  : std::to_string(expectedSize))
1832  << " instead of " << resultSize << " in dim = " << idx;
1833  }
1834 
1835  // Match offset and strides in static_offsets and static_strides attributes.
1836  // If the result memref type has no affine map specified, this assumes an
1837  // identity layout.
1838  int64_t resultOffset;
1839  SmallVector<int64_t, 4> resultStrides;
1840  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1841  return emitError("expected result type to have strided layout but found ")
1842  << resultType;
1843 
1844  // Match offset in result memref type and in static_offsets attribute.
1845  int64_t expectedOffset = getStaticOffsets().front();
1846  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
1847  return emitError("expected result type with offset = ")
1848  << (ShapedType::isDynamic(expectedOffset)
1849  ? std::string("dynamic")
1850  : std::to_string(expectedOffset))
1851  << " instead of " << resultOffset;
1852 
1853  // Match strides in result memref type and in static_strides attribute.
1854  for (auto [idx, resultStride, expectedStride] :
1855  llvm::enumerate(resultStrides, getStaticStrides())) {
1856  if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
1857  return emitError("expected result type with stride = ")
1858  << (ShapedType::isDynamic(expectedStride)
1859  ? std::string("dynamic")
1860  : std::to_string(expectedStride))
1861  << " instead of " << resultStride << " in dim = " << idx;
1862  }
1863 
1864  return success();
1865 }
1866 
1867 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1868  Value src = getSource();
1869  auto getPrevSrc = [&]() -> Value {
1870  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1871  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1872  return prev.getSource();
1873 
1874  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1875  if (auto prev = src.getDefiningOp<CastOp>())
1876  return prev.getSource();
1877 
1878  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1879  // are 0.
1880  if (auto prev = src.getDefiningOp<SubViewOp>())
1881  if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
1882  return prev.getSource();
1883 
1884  return nullptr;
1885  };
1886 
1887  if (auto prevSrc = getPrevSrc()) {
1888  getSourceMutable().assign(prevSrc);
1889  return getResult();
1890  }
1891 
1892  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1893  if (ShapedType::isStaticShape(getType().getShape()) &&
1894  src.getType() == getType() && getStaticOffsets().front() == 0) {
1895  return src;
1896  }
1897 
1898  return nullptr;
1899 }
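// As a sketch, the source-forwarding part of the fold above makes a chained
// reinterpret_cast take its base directly from the original memref (types and
// constants are illustrative placeholders):
//
//   %a = memref.reinterpret_cast %x to offset: [2], sizes: [4], strides: [1]
//          : memref<?xf32> to memref<4xf32, strided<[1], offset: 2>>
//   %b = memref.reinterpret_cast %a to offset: [0], sizes: [2, 2], strides: [2, 1]
//          : memref<4xf32, strided<[1], offset: 2>> to memref<2x2xf32, strided<[2, 1]>>
//
// After folding, %b uses %x directly as its source and %a may become dead.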
1900 
1901 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1902  SmallVector<OpFoldResult> values = getMixedSizes();
1903  constifyIndexValues(values, getType().getShape());
1904  return values;
1905 }
1906 
1907 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1908  SmallVector<OpFoldResult> values = getMixedStrides();
1909  SmallVector<int64_t> staticValues;
1910  int64_t unused;
1911  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
1912  (void)status;
1913  assert(succeeded(status) && "could not get strides from type");
1914  constifyIndexValues(values, staticValues);
1915  return values;
1916 }
1917 
1918 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1919  SmallVector<OpFoldResult> values = getMixedOffsets();
1920  assert(values.size() == 1 &&
1921  "reinterpret_cast must have one and only one offset");
1922  SmallVector<int64_t> staticValues, unused;
1923  int64_t offset;
1924  LogicalResult status = getType().getStridesAndOffset(unused, offset);
1925  (void)status;
1926  assert(succeeded(status) && "could not get offset from type");
1927  staticValues.push_back(offset);
1928  constifyIndexValues(values, staticValues);
1929  return values[0];
1930 }
1931 
1932 namespace {
1933 /// Replace the sequence:
1934 /// ```
1935 /// base, offset, sizes, strides = extract_strided_metadata src
1936 /// dst = reinterpret_cast base to offset, sizes, strides
1937 /// ```
1938 /// With
1939 ///
1940 /// ```
1941 /// dst = memref.cast src
1942 /// ```
1943 ///
1944 /// Note: The cast operation is only inserted when the type of dst and src
1945 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
1946 ///
1947 /// This pattern also matches when the offset, sizes, and strides don't come
1948 /// directly from the `extract_strided_metadata`'s results but when it can be
1949 /// statically proven that they would hold the same values.
1950 ///
1951 /// For instance, the following sequence would be replaced:
1952 /// ```
1953 /// base, offset, sizes, strides =
1954 /// extract_strided_metadata memref : memref<3x4xty>
1955 /// dst = reinterpret_cast base to 0, [3, 4], strides
1956 /// ```
1957 /// Because we know (thanks to the type of the input memref) that variable
1958 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
1959 ///
1960 /// Similarly, the following sequence would be replaced:
1961 /// ```
1962 /// c0 = arith.constant 0
1963 /// c4 = arith.constant 4
1964 /// base, offset, sizes, strides =
1965 /// extract_strided_metadata memref : memref<3x4xty>
1966 /// dst = reinterpret_cast base to c0, [3, c4], strides
1967 /// ```
1968 /// Because we know that `offset` and `c0` will hold 0
1969 /// and `c4` will hold 4.
1970 ///
1971 /// If the pattern above does not match, the input of the
1972 /// extract_strided_metadata is always folded into the input of the
1973 /// reinterpret_cast operator. This allows for dead code elimination to get rid
1974 /// of the extract_strided_metadata in some cases.
1975 struct ReinterpretCastOpExtractStridedMetadataFolder
1976  : public OpRewritePattern<ReinterpretCastOp> {
1977 public:
1978  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
1979 
1980  LogicalResult matchAndRewrite(ReinterpretCastOp op,
1981  PatternRewriter &rewriter) const override {
1982  auto extractStridedMetadata =
1983  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
1984  if (!extractStridedMetadata)
1985  return failure();
1986 
1987  // Check if the reinterpret cast reconstructs a memref with the exact same
1988  // properties as the extract strided metadata.
1989  auto isReinterpretCastNoop = [&]() -> bool {
1990  // First, check that the strides are the same.
1991  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
1992  op.getConstifiedMixedStrides()))
1993  return false;
1994 
1995  // Second, check the sizes.
1996  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
1997  op.getConstifiedMixedSizes()))
1998  return false;
1999 
2000  // Finally, check the offset.
2001  assert(op.getMixedOffsets().size() == 1 &&
2002  "reinterpret_cast with more than one offset should have been "
2003  "rejected by the verifier");
2004  return extractStridedMetadata.getConstifiedMixedOffset() ==
2005  op.getConstifiedMixedOffset();
2006  };
2007 
2008  if (!isReinterpretCastNoop()) {
2009  // If the extract_strided_metadata / reinterpret_cast pair can't be
2010  // completely folded, then we could fold the input of the
2011  // extract_strided_metadata into the input of the reinterpret_cast
2012  // input. For some cases (e.g., static dimensions) the
2013  // the extract_strided_metadata is eliminated by dead code elimination.
2014  //
2015  // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2016  //
2017  // We can always fold the input of an extract_strided_metadata operator
2018  // to the input of a reinterpret_cast operator, because they point to
2019  // the same memory. Note that the reinterpret_cast does not use the
2020  // layout of its input memref, only its base memory pointer which is
2021  // the same as the base pointer returned by the extract_strided_metadata
2022  // operator and the base pointer of the extract_strided_metadata memref
2023  // input.
2024  rewriter.modifyOpInPlace(op, [&]() {
2025  op.getSourceMutable().assign(extractStridedMetadata.getSource());
2026  });
2027  return success();
2028  }
2029 
2030  // At this point, we know that the back and forth between extract strided
2031  // metadata and reinterpret cast is a noop. However, the final type of the
2032  // reinterpret cast may not be exactly the same as the original memref.
2033  // E.g., it could be changing a dimension from static to dynamic. Check that
2034  // here and add a cast if necessary.
2035  Type srcTy = extractStridedMetadata.getSource().getType();
2036  if (srcTy == op.getResult().getType())
2037  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2038  else
2039  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2040  extractStridedMetadata.getSource());
2041 
2042  return success();
2043  }
2044 };
2045 } // namespace
2046 
2047 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2048  MLIRContext *context) {
2049  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2050 }
2051 
2052 //===----------------------------------------------------------------------===//
2053 // Reassociative reshape ops
2054 //===----------------------------------------------------------------------===//
2055 
2056 void CollapseShapeOp::getAsmResultNames(
2057  function_ref<void(Value, StringRef)> setNameFn) {
2058  setNameFn(getResult(), "collapse_shape");
2059 }
2060 
2061 void ExpandShapeOp::getAsmResultNames(
2062  function_ref<void(Value, StringRef)> setNameFn) {
2063  setNameFn(getResult(), "expand_shape");
2064 }
2065 
2066 LogicalResult ExpandShapeOp::reifyResultShapes(
2067  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2068  reifiedResultShapes = {
2069  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2070  return success();
2071 }
2072 
2073 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2074 /// result and operand. Layout maps are verified separately.
2075 ///
2076 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2077 /// allowed in a reassociation group.
2078 static LogicalResult
2079 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2080  ArrayRef<int64_t> expandedShape,
2081  ArrayRef<ReassociationIndices> reassociation,
2082  bool allowMultipleDynamicDimsPerGroup) {
2083  // There must be one reassociation group per collapsed dimension.
2084  if (collapsedShape.size() != reassociation.size())
2085  return op->emitOpError("invalid number of reassociation groups: found ")
2086  << reassociation.size() << ", expected " << collapsedShape.size();
2087 
2088  // The next expected expanded dimension index (while iterating over
2089  // reassociation indices).
2090  int64_t nextDim = 0;
2091  for (const auto &it : llvm::enumerate(reassociation)) {
2092  ReassociationIndices group = it.value();
2093  int64_t collapsedDim = it.index();
2094 
2095  bool foundDynamic = false;
2096  for (int64_t expandedDim : group) {
2097  if (expandedDim != nextDim++)
2098  return op->emitOpError("reassociation indices must be contiguous");
2099 
2100  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2101  return op->emitOpError("reassociation index ")
2102  << expandedDim << " is out of bounds";
2103 
2104  // Check if there are multiple dynamic dims in a reassociation group.
2105  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2106  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2107  return op->emitOpError(
2108  "at most one dimension in a reassociation group may be dynamic");
2109  foundDynamic = true;
2110  }
2111  }
2112 
2113  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2114  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2115  return op->emitOpError("collapsed dim (")
2116  << collapsedDim
2117  << ") must be dynamic if and only if reassociation group is "
2118  "dynamic";
2119 
2120  // If all dims in the reassociation group are static, the size of the
2121  // collapsed dim can be verified.
2122  if (!foundDynamic) {
2123  int64_t groupSize = 1;
2124  for (int64_t expandedDim : group)
2125  groupSize *= expandedShape[expandedDim];
2126  if (groupSize != collapsedShape[collapsedDim])
2127  return op->emitOpError("collapsed dim size (")
2128  << collapsedShape[collapsedDim]
2129  << ") must equal reassociation group size (" << groupSize << ")";
2130  }
2131  }
2132 
2133  if (collapsedShape.empty()) {
2134  // Rank 0: All expanded dimensions must be 1.
2135  for (int64_t d : expandedShape)
2136  if (d != 1)
2137  return op->emitOpError(
2138  "rank 0 memrefs can only be extended/collapsed with/from ones");
2139  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2140  // Rank >= 1: Number of dimensions among all reassociation groups must match
2141  // the result memref rank.
2142  return op->emitOpError("expanded rank (")
2143  << expandedShape.size()
2144  << ") inconsistent with number of reassociation indices (" << nextDim
2145  << ")";
2146  }
2147 
2148  return success();
2149 }
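// For example (illustrative shapes), reassociation [[0, 1], [2]] relates the
// expanded shape 3x4x5 to the collapsed shape 12x5: group {0, 1} multiplies to
// 12 and group {2} stays 5. With a dynamic expanded dim, e.g. 3x?x5, the
// corresponding collapsed dim must be dynamic as well (?x5).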
2150 
2151 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2152  return getSymbolLessAffineMaps(getReassociationExprs());
2153 }
2154 
2155 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2156  return convertReassociationIndicesToExprs(getContext(),
2157  getReassociationIndices());
2158 }
2159 
2160 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2161  return getSymbolLessAffineMaps(getReassociationExprs());
2162 }
2163 
2164 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2165  return convertReassociationIndicesToExprs(getContext(),
2166  getReassociationIndices());
2167 }
2168 
2169 /// Compute the layout map after expanding a given source MemRef type with the
2170 /// specified reassociation indices.
2171 static FailureOr<StridedLayoutAttr>
2172 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2173  ArrayRef<ReassociationIndices> reassociation) {
2174  int64_t srcOffset;
2175  SmallVector<int64_t> srcStrides;
2176  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2177  return failure();
2178  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2179 
2180  // 1-1 mapping between srcStrides and reassociation packs.
2181  // Each srcStride starts with the given value and gets expanded according to
2182  // the proper entries in resultShape.
2183  // Example:
2184  // srcStrides = [10000, 1 , 100 ],
2185  // reassociations = [ [0], [1], [2, 3, 4]],
2186  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2187  // -> For the purpose of stride calculation, the useful sizes are:
2188  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2189  // resultStrides = [10000, 1, 600, 200, 100]
2190  // Note that a stride does not get expanded along the first entry of each
2191  // shape pack.
2192  SmallVector<int64_t> reverseResultStrides;
2193  reverseResultStrides.reserve(resultShape.size());
2194  unsigned shapeIndex = resultShape.size() - 1;
2195  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2196  ReassociationIndices reassoc = std::get<0>(it);
2197  int64_t currentStrideToExpand = std::get<1>(it);
2198  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2199  reverseResultStrides.push_back(currentStrideToExpand);
2200  currentStrideToExpand =
2201  (SaturatedInteger::wrap(currentStrideToExpand) *
2202  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2203  .asInteger();
2204  }
2205  }
2206  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2207  resultStrides.resize(resultShape.size(), 1);
2208  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2209 }
2210 
2211 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2212  MemRefType srcType, ArrayRef<int64_t> resultShape,
2213  ArrayRef<ReassociationIndices> reassociation) {
2214  if (srcType.getLayout().isIdentity()) {
2215  // If the source is contiguous (i.e., no layout map specified), so is the
2216  // result.
2217  MemRefLayoutAttrInterface layout;
2218  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2219  srcType.getMemorySpace());
2220  }
2221 
2222  // Source may not be contiguous. Compute the layout map.
2223  FailureOr<StridedLayoutAttr> computedLayout =
2224  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2225  if (failed(computedLayout))
2226  return failure();
2227  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2228  srcType.getMemorySpace());
2229 }
2230 
2231 FailureOr<SmallVector<OpFoldResult>>
2232 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2233  MemRefType expandedType,
2234  ArrayRef<ReassociationIndices> reassociation,
2235  ArrayRef<OpFoldResult> inputShape) {
2236  std::optional<SmallVector<OpFoldResult>> outputShape =
2237  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2238  inputShape);
2239  if (!outputShape)
2240  return failure();
2241  return *outputShape;
2242 }
2243 
2244 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2245  Type resultType, Value src,
2246  ArrayRef<ReassociationIndices> reassociation,
2247  ArrayRef<OpFoldResult> outputShape) {
2248  auto [staticOutputShape, dynamicOutputShape] =
2249  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2250  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2251  getReassociationIndicesAttribute(builder, reassociation),
2252  dynamicOutputShape, staticOutputShape);
2253 }
2254 
2255 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2256  Type resultType, Value src,
2257  ArrayRef<ReassociationIndices> reassociation) {
2258  SmallVector<OpFoldResult> inputShape =
2259  getMixedSizes(builder, result.location, src);
2260  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2261  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2262  builder, result.location, memrefResultTy, reassociation, inputShape);
2263  // Failure of this assertion usually indicates presence of multiple
2264  // dynamic dimensions in the same reassociation group.
2265  assert(succeeded(outputShape) && "unable to infer output shape");
2266  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2267 }
2268 
2269 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2270  ArrayRef<int64_t> resultShape, Value src,
2271  ArrayRef<ReassociationIndices> reassociation) {
2272  // Only ranked memref source values are supported.
2273  auto srcType = llvm::cast<MemRefType>(src.getType());
2274  FailureOr<MemRefType> resultType =
2275  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2276  // Failure of this assertion usually indicates a problem with the source
2277  // type, e.g., could not get strides/offset.
2278  assert(succeeded(resultType) && "could not compute layout");
2279  build(builder, result, *resultType, src, reassociation);
2280 }
2281 
2282 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2283  ArrayRef<int64_t> resultShape, Value src,
2284  ArrayRef<ReassociationIndices> reassociation,
2285  ArrayRef<OpFoldResult> outputShape) {
2286  // Only ranked memref source values are supported.
2287  auto srcType = llvm::cast<MemRefType>(src.getType());
2288  FailureOr<MemRefType> resultType =
2289  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2290  // Failure of this assertion usually indicates a problem with the source
2291  // type, e.g., could not get strides/offset.
2292  assert(succeeded(resultType) && "could not compute layout");
2293  build(builder, result, *resultType, src, reassociation, outputShape);
2294 }
2295 
2296 LogicalResult ExpandShapeOp::verify() {
2297  MemRefType srcType = getSrcType();
2298  MemRefType resultType = getResultType();
2299 
2300  if (srcType.getRank() > resultType.getRank()) {
2301  auto r0 = srcType.getRank();
2302  auto r1 = resultType.getRank();
2303  return emitOpError("has source rank ")
2304  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2305  << r0 << " > " << r1 << ").";
2306  }
2307 
2308  // Verify result shape.
2309  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2310  resultType.getShape(),
2311  getReassociationIndices(),
2312  /*allowMultipleDynamicDimsPerGroup=*/true)))
2313  return failure();
2314 
2315  // Compute expected result type (including layout map).
2316  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2317  srcType, resultType.getShape(), getReassociationIndices());
2318  if (failed(expectedResultType))
2319  return emitOpError("invalid source layout map");
2320 
2321  // Check actual result type.
2322  if (*expectedResultType != resultType)
2323  return emitOpError("expected expanded type to be ")
2324  << *expectedResultType << " but found " << resultType;
2325 
2326  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2327  return emitOpError("expected number of static shape bounds to be equal to "
2328  "the output rank (")
2329  << resultType.getRank() << ") but found "
2330  << getStaticOutputShape().size() << " inputs instead";
2331 
2332  if ((int64_t)getOutputShape().size() !=
2333  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2334  return emitOpError("mismatch in dynamic dims in output_shape and "
2335  "static_output_shape: static_output_shape has ")
2336  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2337  << " dynamic dims while output_shape has " << getOutputShape().size()
2338  << " values";
2339 
2340  // Verify that the provided output shapes are in agreement with the output type.
2341  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2342  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2343  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2344  if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
2345  return emitOpError("invalid output shape provided at pos ") << pos;
2346  }
2347  }
2348 
2349  return success();
2350 }
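// For example (illustrative shapes), a well-formed expansion with one dynamic
// dimension per reassociation group looks like:
//
//   %e = memref.expand_shape %m [[0, 1], [2]] output_shape [%sz0, 4, 5]
//          : memref<?x5xf32> into memref<?x4x5xf32>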
2351 
2352 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2353  MLIRContext *context) {
2354  results.add<
2355  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2356  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2357 }
2358 
2359 /// Compute the layout map after collapsing a given source MemRef type with the
2360 /// specified reassociation indices.
2361 ///
2362 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2363 /// not possible to check this by inspecting a MemRefType in the general case.
2364 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2365 /// be valid (and thus accepted by this function) unless `strict = true`.
2366 static FailureOr<StridedLayoutAttr>
2367 computeCollapsedLayoutMap(MemRefType srcType,
2368  ArrayRef<ReassociationIndices> reassociation,
2369  bool strict = false) {
2370  int64_t srcOffset;
2371  SmallVector<int64_t> srcStrides;
2372  auto srcShape = srcType.getShape();
2373  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2374  return failure();
2375 
2376  // The result stride of a reassociation group is the stride of the last entry
2377  // of the reassociation. (TODO: Should be the minimum stride in the
2378  // reassociation because strides are not necessarily sorted. E.g., when using
2379  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2380  // strides are meaningless and could have any arbitrary value.
2381  SmallVector<int64_t> resultStrides;
2382  resultStrides.reserve(reassociation.size());
2383  for (const ReassociationIndices &reassoc : reassociation) {
2384  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2385  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2386  ref = ref.drop_back();
2387  if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2388  resultStrides.push_back(srcStrides[ref.back()]);
2389  } else {
2390  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2391  // the corresponding stride may have to be skipped. (See above comment.)
2392  // Therefore, the result stride cannot be statically determined and must
2393  // be dynamic.
2394  resultStrides.push_back(ShapedType::kDynamic);
2395  }
2396  }
2397 
2398  // Validate that each reassociation group is contiguous.
2399  unsigned resultStrideIndex = resultStrides.size() - 1;
2400  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2401  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2402  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2403  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2404  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2405 
2406  // Both source and result stride must have the same static value. In that
2407  // case, we can be sure that the dimensions are collapsible (because they
2408  // are contiguous).
2409  // If `strict = false` (default during op verification), we accept cases
2410  // where one or both strides are dynamic. This is best effort: We reject
2411  // ops where obviously non-contiguous dims are collapsed, but accept ops
2412  // where we cannot be sure statically. Such ops may fail at runtime. See
2413  // the op documentation for details.
2414  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2415  if (strict && (stride.saturated || srcStride.saturated))
2416  return failure();
2417 
2418  // Dimensions of size 1 should be skipped, because their strides are
2419  // meaningless and could have any arbitrary value.
2420  if (srcShape[idx - 1] == 1)
2421  continue;
2422 
2423  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2424  return failure();
2425  }
2426  }
2427  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2428 }
2429 
2430 bool CollapseShapeOp::isGuaranteedCollapsible(
2431  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2432  // MemRefs with identity layout are always collapsible.
2433  if (srcType.getLayout().isIdentity())
2434  return true;
2435 
2436  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2437  /*strict=*/true));
2438 }
2439 
2440 MemRefType CollapseShapeOp::computeCollapsedType(
2441  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2442  SmallVector<int64_t> resultShape;
2443  resultShape.reserve(reassociation.size());
2444  for (const ReassociationIndices &group : reassociation) {
2445  auto groupSize = SaturatedInteger::wrap(1);
2446  for (int64_t srcDim : group)
2447  groupSize =
2448  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2449  resultShape.push_back(groupSize.asInteger());
2450  }
2451 
2452  if (srcType.getLayout().isIdentity()) {
2453  // If the source is contiguous (i.e., no layout map specified), so is the
2454  // result.
2455  MemRefLayoutAttrInterface layout;
2456  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2457  srcType.getMemorySpace());
2458  }
2459 
2460  // Source may not be fully contiguous. Compute the layout map.
2461  // Note: Dimensions that are collapsed into a single dim are assumed to be
2462  // contiguous.
2463  FailureOr<StridedLayoutAttr> computedLayout =
2464  computeCollapsedLayoutMap(srcType, reassociation);
2465  assert(succeeded(computedLayout) &&
2466  "invalid source layout map or collapsing non-contiguous dims");
2467  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2468  srcType.getMemorySpace());
2469 }
2470 
2471 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2472  ArrayRef<ReassociationIndices> reassociation,
2473  ArrayRef<NamedAttribute> attrs) {
2474  auto srcType = llvm::cast<MemRefType>(src.getType());
2475  MemRefType resultType =
2476  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2477  result.addAttribute(getReassociationAttrStrName(),
2478  getReassociationIndicesAttribute(b, reassociation));
2479  build(b, result, resultType, src, attrs);
2480 }
2481 
2482 LogicalResult CollapseShapeOp::verify() {
2483  MemRefType srcType = getSrcType();
2484  MemRefType resultType = getResultType();
2485 
2486  if (srcType.getRank() < resultType.getRank()) {
2487  auto r0 = srcType.getRank();
2488  auto r1 = resultType.getRank();
2489  return emitOpError("has source rank ")
2490  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2491  << r0 << " < " << r1 << ").";
2492  }
2493 
2494  // Verify result shape.
2495  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2496  srcType.getShape(), getReassociationIndices(),
2497  /*allowMultipleDynamicDimsPerGroup=*/true)))
2498  return failure();
2499 
2500  // Compute expected result type (including layout map).
2501  MemRefType expectedResultType;
2502  if (srcType.getLayout().isIdentity()) {
2503  // If the source is contiguous (i.e., no layout map specified), so is the
2504  // result.
2505  MemRefLayoutAttrInterface layout;
2506  expectedResultType =
2507  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2508  srcType.getMemorySpace());
2509  } else {
2510  // Source may not be fully contiguous. Compute the layout map.
2511  // Note: Dimensions that are collapsed into a single dim are assumed to be
2512  // contiguous.
2513  FailureOr<StridedLayoutAttr> computedLayout =
2514  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2515  if (failed(computedLayout))
2516  return emitOpError(
2517  "invalid source layout map or collapsing non-contiguous dims");
2518  expectedResultType =
2519  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2520  *computedLayout, srcType.getMemorySpace());
2521  }
2522 
2523  if (expectedResultType != resultType)
2524  return emitOpError("expected collapsed type to be ")
2525  << expectedResultType << " but found " << resultType;
2526 
2527  return success();
2528 }
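// For example (illustrative shapes), collapsing an identity-layout memref:
//
//   %c = memref.collapse_shape %m [[0, 1], [2]]
//          : memref<2x3x4xf32> into memref<6x4xf32>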
2529 
2530 class CollapseShapeOpMemRefCastFolder final
2531  : public OpRewritePattern<CollapseShapeOp> {
2532 public:
2533  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2534 
2535  LogicalResult matchAndRewrite(CollapseShapeOp op,
2536  PatternRewriter &rewriter) const override {
2537  auto cast = op.getOperand().getDefiningOp<CastOp>();
2538  if (!cast)
2539  return failure();
2540 
2541  if (!CastOp::canFoldIntoConsumerOp(cast))
2542  return failure();
2543 
2544  Type newResultType = CollapseShapeOp::computeCollapsedType(
2545  llvm::cast<MemRefType>(cast.getOperand().getType()),
2546  op.getReassociationIndices());
2547 
2548  if (newResultType == op.getResultType()) {
2549  rewriter.modifyOpInPlace(
2550  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2551  } else {
2552  Value newOp =
2553  CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
2554  op.getReassociationIndices());
2555  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2556  }
2557  return success();
2558  }
2559 };
2560 
2561 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2562  MLIRContext *context) {
2563  results.add<
2564  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2565  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2566  memref::DimOp, MemRefType>,
2567  CollapseShapeOpMemRefCastFolder>(context);
2568 }
2569 
2570 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2571  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2572  adaptor.getOperands());
2573 }
2574 
2575 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2576  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2577  adaptor.getOperands());
2578 }
2579 
2580 //===----------------------------------------------------------------------===//
2581 // ReshapeOp
2582 //===----------------------------------------------------------------------===//
2583 
2584 void ReshapeOp::getAsmResultNames(
2585  function_ref<void(Value, StringRef)> setNameFn) {
2586  setNameFn(getResult(), "reshape");
2587 }
2588 
2589 LogicalResult ReshapeOp::verify() {
2590  Type operandType = getSource().getType();
2591  Type resultType = getResult().getType();
2592 
2593  Type operandElementType =
2594  llvm::cast<ShapedType>(operandType).getElementType();
2595  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2596  if (operandElementType != resultElementType)
2597  return emitOpError("element types of source and destination memref "
2598  "types should be the same");
2599 
2600  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2601  if (!operandMemRefType.getLayout().isIdentity())
2602  return emitOpError("source memref type should have identity affine map");
2603 
2604  int64_t shapeSize =
2605  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2606  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2607  if (resultMemRefType) {
2608  if (!resultMemRefType.getLayout().isIdentity())
2609  return emitOpError("result memref type should have identity affine map");
2610  if (shapeSize == ShapedType::kDynamic)
2611  return emitOpError("cannot use shape operand with dynamic length to "
2612  "reshape to statically-ranked memref type");
2613  if (shapeSize != resultMemRefType.getRank())
2614  return emitOpError(
2615  "length of shape operand differs from the result's memref rank");
2616  }
2617  return success();
2618 }
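// For example (illustrative types), reshaping with a statically sized shape
// operand to a statically ranked result:
//
//   %dst = memref.reshape %src(%shape)
//            : (memref<4x5xf32>, memref<2xi32>) -> memref<?x?xf32>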
2619 
2620 //===----------------------------------------------------------------------===//
2621 // StoreOp
2622 //===----------------------------------------------------------------------===//
2623 
2624 LogicalResult StoreOp::verify() {
2625  if (getNumOperands() != 2 + getMemRefType().getRank())
2626  return emitOpError("store index operand count not equal to memref rank");
2627 
2628  return success();
2629 }
2630 
2631 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2632  SmallVectorImpl<OpFoldResult> &results) {
2633  /// store(memrefcast) -> store
2634  return foldMemRefCast(*this, getValueToStore());
2635 }
2636 
2637 //===----------------------------------------------------------------------===//
2638 // SubViewOp
2639 //===----------------------------------------------------------------------===//
2640 
2641 void SubViewOp::getAsmResultNames(
2642  function_ref<void(Value, StringRef)> setNameFn) {
2643  setNameFn(getResult(), "subview");
2644 }
2645 
2646 /// A subview result type can be fully inferred from the source type and the
2647 /// static representation of offsets, sizes and strides. Special sentinels
2648 /// encode the dynamic case.
2649 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2650  ArrayRef<int64_t> staticOffsets,
2651  ArrayRef<int64_t> staticSizes,
2652  ArrayRef<int64_t> staticStrides) {
2653  unsigned rank = sourceMemRefType.getRank();
2654  (void)rank;
2655  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2656  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2657  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2658 
2659  // Extract source offset and strides.
2660  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2661 
2662  // Compute target offset whose value is:
2663  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2664  int64_t targetOffset = sourceOffset;
2665  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2666  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2667  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2668  SaturatedInteger::wrap(staticOffset) *
2669  SaturatedInteger::wrap(sourceStride))
2670  .asInteger();
2671  }
2672 
2673  // Compute target stride whose value is:
2674  // `sourceStrides_i * staticStrides_i`.
2675  SmallVector<int64_t, 4> targetStrides;
2676  targetStrides.reserve(staticOffsets.size());
2677  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2678  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2679  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2680  SaturatedInteger::wrap(staticStride))
2681  .asInteger());
2682  }
2683 
2684  // The type is now known.
2685  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2686  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2687  targetOffset, targetStrides),
2688  sourceMemRefType.getMemorySpace());
2689 }
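// As a worked example (illustrative numbers): for a source type
// memref<8x16xf32, strided<[16, 1]>>, offsets [2, 4], sizes [4, 8], and
// strides [1, 2]:
//   targetOffset  = 0 + 2 * 16 + 4 * 1 = 36
//   targetStrides = [16 * 1, 1 * 2]    = [16, 2]
// giving memref<4x8xf32, strided<[16, 2], offset: 36>>.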
2690 
2691 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2692  ArrayRef<OpFoldResult> offsets,
2693  ArrayRef<OpFoldResult> sizes,
2694  ArrayRef<OpFoldResult> strides) {
2695  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2696  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2697  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2698  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2699  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2700  if (!hasValidSizesOffsets(staticOffsets))
2701  return {};
2702  if (!hasValidSizesOffsets(staticSizes))
2703  return {};
2704  if (!hasValidStrides(staticStrides))
2705  return {};
2706  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2707  staticSizes, staticStrides);
2708 }
2709 
2710 MemRefType SubViewOp::inferRankReducedResultType(
2711  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2712  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2713  ArrayRef<int64_t> strides) {
2714  MemRefType inferredType =
2715  inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2716  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2717  "expected result shape rank to not exceed the inferred rank");
2718  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2719  return inferredType;
2720 
2721  // Compute which dimensions are dropped.
2722  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2723  computeRankReductionMask(inferredType.getShape(), resultShape);
2724  assert(dimsToProject.has_value() && "invalid rank reduction");
2725 
2726  // Compute the layout and result type.
2727  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2728  SmallVector<int64_t> rankReducedStrides;
2729  rankReducedStrides.reserve(resultShape.size());
2730  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2731  if (!dimsToProject->contains(idx))
2732  rankReducedStrides.push_back(value);
2733  }
2734  return MemRefType::get(resultShape, inferredType.getElementType(),
2735  StridedLayoutAttr::get(inferredLayout.getContext(),
2736  inferredLayout.getOffset(),
2737  rankReducedStrides),
2738  inferredType.getMemorySpace());
2739 }
2740 
2741 MemRefType SubViewOp::inferRankReducedResultType(
2742  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2743  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2744  ArrayRef<OpFoldResult> strides) {
2745  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2746  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2747  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2748  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2749  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2750  return SubViewOp::inferRankReducedResultType(
2751  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2752  staticStrides);
2753 }
2754 
2755 // Build a SubViewOp with mixed static and dynamic entries and custom result
2756 // type. If the type passed is nullptr, it is inferred.
2757 void SubViewOp::build(OpBuilder &b, OperationState &result,
2758  MemRefType resultType, Value source,
2759  ArrayRef<OpFoldResult> offsets,
2760  ArrayRef<OpFoldResult> sizes,
2761  ArrayRef<OpFoldResult> strides,
2762  ArrayRef<NamedAttribute> attrs) {
2763  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2764  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2765  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2766  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2767  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2768  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2769  // Structuring implementation this way avoids duplication between builders.
2770  if (!resultType) {
2771  resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2772  staticSizes, staticStrides);
2773  }
2774  result.addAttributes(attrs);
2775  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2776  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2777  b.getDenseI64ArrayAttr(staticSizes),
2778  b.getDenseI64ArrayAttr(staticStrides));
2779 }
2780 
2781 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2782 // type.
2783 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2784  ArrayRef<OpFoldResult> offsets,
2785  ArrayRef<OpFoldResult> sizes,
2786  ArrayRef<OpFoldResult> strides,
2787  ArrayRef<NamedAttribute> attrs) {
2788  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2789 }
2790 
2791 // Build a SubViewOp with static entries and inferred result type.
2792 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2793  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2794  ArrayRef<int64_t> strides,
2795  ArrayRef<NamedAttribute> attrs) {
2796  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2797  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2798  return b.getI64IntegerAttr(v);
2799  }));
2800  SmallVector<OpFoldResult> sizeValues =
2801  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2802  return b.getI64IntegerAttr(v);
2803  }));
2804  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2805  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2806  return b.getI64IntegerAttr(v);
2807  }));
2808  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2809 }
2810 
2811 // Build a SubViewOp with static entries and custom result type. If the
2812 // type passed is nullptr, it is inferred.
2813 void SubViewOp::build(OpBuilder &b, OperationState &result,
2814  MemRefType resultType, Value source,
2815  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2816  ArrayRef<int64_t> strides,
2817  ArrayRef<NamedAttribute> attrs) {
2818  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2819  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2820  return b.getI64IntegerAttr(v);
2821  }));
2822  SmallVector<OpFoldResult> sizeValues =
2823  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2824  return b.getI64IntegerAttr(v);
2825  }));
2826  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2827  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2828  return b.getI64IntegerAttr(v);
2829  }));
2830  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2831  attrs);
2832 }
2833 
2834 // Build a SubViewOp with dynamic entries and custom result type. If the type
2835 // passed is nullptr, it is inferred.
2836 void SubViewOp::build(OpBuilder &b, OperationState &result,
2837  MemRefType resultType, Value source, ValueRange offsets,
2838  ValueRange sizes, ValueRange strides,
2839  ArrayRef<NamedAttribute> attrs) {
2840  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2841  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2842  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2843  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2844  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2845  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2846  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2847 }
2848 
2849 // Build a SubViewOp with dynamic entries and inferred result type.
2850 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2851  ValueRange offsets, ValueRange sizes, ValueRange strides,
2852  ArrayRef<NamedAttribute> attrs) {
2853  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2854 }
2855 
2856 /// For ViewLikeOpInterface.
2857 Value SubViewOp::getViewSource() { return getSource(); }
2858 
2859 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2860 /// static value).
2861 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2862  int64_t t1Offset, t2Offset;
2863  SmallVector<int64_t> t1Strides, t2Strides;
2864  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2865  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2866  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2867 }
2868 
2869 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2870 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2871 /// marked as dropped in `droppedDims`.
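/// For illustration (an assumed example, not from the original source): with
/// `t1` = memref<4x1x8xf32, strided<[8, 8, 1]>>,
/// `t2` = memref<4x8xf32, strided<[8, 1]>>, and `droppedDims` = {1}, the
/// strides of the kept dimensions ([8, 1] on both sides) match, so the
/// function returns true.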
2872 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2873  const llvm::SmallBitVector &droppedDims) {
2874  assert(size_t(t1.getRank()) == droppedDims.size() &&
2875  "incorrect number of bits");
2876  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2877  "incorrect number of dropped dims");
2878  int64_t t1Offset, t2Offset;
2879  SmallVector<int64_t> t1Strides, t2Strides;
2880  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2881  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2882  if (failed(res1) || failed(res2))
2883  return false;
2884  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2885  if (droppedDims[i])
2886  continue;
2887  if (t1Strides[i] != t2Strides[j])
2888  return false;
2889  ++j;
2890  }
2891  return true;
2892 }
2893 
2894 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2895                                             SubViewOp op, Type expectedType) {
2896  auto memrefType = llvm::cast<ShapedType>(expectedType);
2897   switch (result) {
2898   case SliceVerificationResult::Success:
2899     return success();
2900   case SliceVerificationResult::RankTooLarge:
2901     return op->emitError("expected result rank to be smaller or equal to ")
2902            << "the source rank, but got " << op.getType();
2903   case SliceVerificationResult::SizeMismatch:
2904     return op->emitError("expected result type to be ")
2905            << expectedType
2906            << " or a rank-reduced version. (mismatch of result sizes), but got "
2907            << op.getType();
2908   case SliceVerificationResult::ElemTypeMismatch:
2909     return op->emitError("expected result element type to be ")
2910            << memrefType.getElementType() << ", but got " << op.getType();
2911   case SliceVerificationResult::MemSpaceMismatch:
2912     return op->emitError(
2913                "expected result and source memory spaces to match, but got ")
2914            << op.getType();
2915   case SliceVerificationResult::LayoutMismatch:
2916     return op->emitError("expected result type to be ")
2917            << expectedType
2918            << " or a rank-reduced version. (mismatch of result layout), but "
2919               "got "
2920            << op.getType();
2921  }
2922  llvm_unreachable("unexpected subview verification result");
2923 }
2924 
2925 /// Verifier for SubViewOp.
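///
/// For illustration (an assumed example, not from the original source), a
/// rank-reducing subview accepted by this verifier:
/// ```
/// %0 = memref.subview %src[4, 2] [1, 8] [1, 1]
///     : memref<16x16xf32> to memref<8xf32, strided<[1], offset: 66>>
/// ```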
2926 LogicalResult SubViewOp::verify() {
2927  MemRefType baseType = getSourceType();
2928  MemRefType subViewType = getType();
2929  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
2930  ArrayRef<int64_t> staticSizes = getStaticSizes();
2931  ArrayRef<int64_t> staticStrides = getStaticStrides();
2932 
2933  // The base memref and the view memref should be in the same memory space.
2934  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2935  return emitError("different memory spaces specified for base memref "
2936  "type ")
2937  << baseType << " and subview memref type " << subViewType;
2938 
2939  // Verify that the base memref type has a strided layout map.
2940  if (!baseType.isStrided())
2941  return emitError("base type ") << baseType << " is not strided";
2942 
2943  // Compute the expected result type, assuming that there are no rank
2944  // reductions.
2945  MemRefType expectedType = SubViewOp::inferResultType(
2946  baseType, staticOffsets, staticSizes, staticStrides);
2947 
2948  // Verify all properties of a shaped type: rank, element type and dimension
2949  // sizes. This takes into account potential rank reductions.
2950  auto shapedTypeVerification = isRankReducedType(
2951  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2952  if (shapedTypeVerification != SliceVerificationResult::Success)
2953  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2954 
2955  // Make sure that the memory space did not change.
2956  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
2957     return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
2958                                   *this, expectedType);
2959 
2960  // Verify the offset of the layout map.
2961  if (!haveCompatibleOffsets(expectedType, subViewType))
2962     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2963                                   *this, expectedType);
2964 
2965   // Only the strides are left to verify now. First, compute
2966  // the unused dimensions due to rank reductions. We have to look at sizes and
2967  // strides to decide which dimensions were dropped. This function also
2968  // partially verifies strides in case of rank reductions.
2969  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
2970  getMixedSizes());
2971  if (failed(unusedDims))
2972     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2973                                   *this, expectedType);
2974 
2975  // Strides must match.
2976  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
2977     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2978                                   *this, expectedType);
2979 
2980  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2981  // to the base memref.
2982  SliceBoundsVerificationResult boundsResult =
2983  verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
2984  staticStrides, /*generateErrorMessage=*/true);
2985  if (!boundsResult.isValid)
2986  return getOperation()->emitError(boundsResult.errorMessage);
2987 
2988  return success();
2989 }
2990 
2991 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2992  return os << "range " << range.offset << ":" << range.size << ":"
2993  << range.stride;
2994 }
2995 
2996 /// Return the list of Range (i.e. offset, size, stride). Each Range
2997 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2998 /// with `b` at location `loc`.
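/// A minimal usage sketch (hypothetical caller, not from this file):
/// ```
/// SmallVector<Range, 8> ranges = getOrCreateRanges(offsetSizeAndStrideOp, b, loc);
/// // Each Range member wraps either the op's dynamic Value or a newly
/// // materialized arith constant of index type.
/// ```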
2999 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3000  OpBuilder &b, Location loc) {
3001  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3002  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3003  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3004   SmallVector<Range, 8> res;
3005   unsigned rank = ranks[0];
3006  res.reserve(rank);
3007  for (unsigned idx = 0; idx < rank; ++idx) {
3008  Value offset =
3009  op.isDynamicOffset(idx)
3010  ? op.getDynamicOffset(idx)
3011  : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
3012  Value size =
3013  op.isDynamicSize(idx)
3014  ? op.getDynamicSize(idx)
3015  : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
3016  Value stride =
3017  op.isDynamicStride(idx)
3018  ? op.getDynamicStride(idx)
3019  : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
3020  res.emplace_back(Range{offset, size, stride});
3021  }
3022  return res;
3023 }
3024 
3025 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3026 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3027 /// the rank of the inferred result type if `currentResultType` has a lower rank
3028 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3029 /// together with the result type. In this case, it is important to compute
3030 /// the dropped dimensions using `currentSourceType` whose strides align with
3031 /// `currentResultType`.
3032 static MemRefType getCanonicalSubViewResultType(
3033     MemRefType currentResultType, MemRefType currentSourceType,
3034  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3035  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3036  MemRefType nonRankReducedType = SubViewOp::inferResultType(
3037  sourceType, mixedOffsets, mixedSizes, mixedStrides);
3038  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3039  currentSourceType, currentResultType, mixedSizes);
3040  if (failed(unusedDims))
3041  return nullptr;
3042 
3043  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3044  SmallVector<int64_t> shape, strides;
3045  unsigned numDimsAfterReduction =
3046  nonRankReducedType.getRank() - unusedDims->count();
3047  shape.reserve(numDimsAfterReduction);
3048  strides.reserve(numDimsAfterReduction);
3049  for (const auto &[idx, size, stride] :
3050  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3051  nonRankReducedType.getShape(), layout.getStrides())) {
3052  if (unusedDims->test(idx))
3053  continue;
3054  shape.push_back(size);
3055  strides.push_back(stride);
3056  }
3057 
3058  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3059  StridedLayoutAttr::get(sourceType.getContext(),
3060  layout.getOffset(), strides),
3061  nonRankReducedType.getMemorySpace());
3062 }
3063 
3064 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3065     OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3066  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3067  unsigned rank = memrefType.getRank();
3068  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3069  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3070  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3071  MemRefType targetType = SubViewOp::inferRankReducedResultType(
3072  targetShape, memrefType, offsets, sizes, strides);
3073  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3074  sizes, strides);
3075 }
3076 
3077 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3078  Value value,
3079  ArrayRef<int64_t> desiredShape) {
3080  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3081  assert(sourceMemrefType && "not a ranked memref type");
3082  auto sourceShape = sourceMemrefType.getShape();
3083  if (sourceShape.equals(desiredShape))
3084  return value;
3085  auto maybeRankReductionMask =
3086  mlir::computeRankReductionMask(sourceShape, desiredShape);
3087  if (!maybeRankReductionMask)
3088  return failure();
3089  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3090 }
3091 
3092 /// Helper method to check if a `subview` operation is trivially a no-op. This
3093 /// is the case if all offsets are zero, all strides are 1, and the source
3094 /// shape matches the size of the subview. In such cases, the subview can
3095 /// be folded into its source.
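///
/// For illustration (an assumed example, not from the original source), a
/// subview of this form is trivially a no-op:
/// ```
/// %0 = memref.subview %src[0, 0] [4, 8] [1, 1]
///     : memref<4x8xf32> to memref<4x8xf32>
/// ```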
3096 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3097  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3098  return false;
3099 
3100  auto mixedOffsets = subViewOp.getMixedOffsets();
3101  auto mixedSizes = subViewOp.getMixedSizes();
3102  auto mixedStrides = subViewOp.getMixedStrides();
3103 
3104  // Check offsets are zero.
3105  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3106  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3107  return !intValue || intValue.value() != 0;
3108  }))
3109  return false;
3110 
3111  // Check strides are one.
3112  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3113  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3114  return !intValue || intValue.value() != 1;
3115  }))
3116  return false;
3117 
3118   // Check that all size values are static and match the (static) source shape.
3119  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3120  for (const auto &size : llvm::enumerate(mixedSizes)) {
3121  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3122  if (!intValue || *intValue != sourceShape[size.index()])
3123  return false;
3124  }
3125  // All conditions met. The `SubViewOp` is foldable as a no-op.
3126  return true;
3127 }
3128 
3129 namespace {
3130 /// Pattern to rewrite a subview op with MemRefCast arguments.
3131 /// This essentially pushes memref.cast past its consuming subview when
3132 /// `canFoldIntoConsumerOp` is true.
3133 ///
3134 /// Example:
3135 /// ```
3136 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3137 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3138 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3139 /// ```
3140 /// is rewritten into:
3141 /// ```
3142 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3143 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3144 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3145 /// ```
3146 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3147 public:
3148   using OpRewritePattern<SubViewOp>::OpRewritePattern;
3149 
3150  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3151  PatternRewriter &rewriter) const override {
3152     // If any operand is a constant, just return and let
3153     // SubViewOpConstantFolder kick in.
3154  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3155  return matchPattern(operand, matchConstantIndex());
3156  }))
3157  return failure();
3158 
3159  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3160  if (!castOp)
3161  return failure();
3162 
3163  if (!CastOp::canFoldIntoConsumerOp(castOp))
3164  return failure();
3165 
3166  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3167  // the MemRefCastOp source operand type to infer the result type and the
3168  // current SubViewOp source operand type to compute the dropped dimensions
3169  // if the operation is rank-reducing.
3170  auto resultType = getCanonicalSubViewResultType(
3171  subViewOp.getType(), subViewOp.getSourceType(),
3172  llvm::cast<MemRefType>(castOp.getSource().getType()),
3173  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3174  subViewOp.getMixedStrides());
3175  if (!resultType)
3176  return failure();
3177 
3178  Value newSubView = SubViewOp::create(
3179  rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
3180  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3181  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3182  subViewOp.getStaticStrides());
3183  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3184  newSubView);
3185  return success();
3186  }
3187 };
3188 
3189 /// Canonicalize subview ops that are no-ops. If the source and result types
3190 /// differ only in layout (e.g., due to an `affine_map`), a cast is inserted.
3191 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3192 public:
3193   using OpRewritePattern<SubViewOp>::OpRewritePattern;
3194 
3195  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3196  PatternRewriter &rewriter) const override {
3197  if (!isTrivialSubViewOp(subViewOp))
3198  return failure();
3199  if (subViewOp.getSourceType() == subViewOp.getType()) {
3200  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3201  return success();
3202  }
3203  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3204  subViewOp.getSource());
3205  return success();
3206  }
3207 };
3208 } // namespace
3209 
3210 /// Return the canonical type of the result of a subview.
3211 struct SubViewReturnTypeCanonicalizer {
3212   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3213  ArrayRef<OpFoldResult> mixedSizes,
3214  ArrayRef<OpFoldResult> mixedStrides) {
3215  // Infer a memref type without taking into account any rank reductions.
3216  MemRefType resTy = SubViewOp::inferResultType(
3217  op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3218  if (!resTy)
3219  return {};
3220  MemRefType nonReducedType = resTy;
3221 
3222  // Directly return the non-rank reduced type if there are no dropped dims.
3223  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3224  if (droppedDims.none())
3225  return nonReducedType;
3226 
3227  // Take the strides and offset from the non-rank reduced type.
3228  auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3229 
3230  // Drop dims from shape and strides.
3231  SmallVector<int64_t> targetShape;
3232  SmallVector<int64_t> targetStrides;
3233  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3234  if (droppedDims.test(i))
3235  continue;
3236  targetStrides.push_back(nonReducedStrides[i]);
3237  targetShape.push_back(nonReducedType.getDimSize(i));
3238  }
3239 
3240  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3241  StridedLayoutAttr::get(nonReducedType.getContext(),
3242  offset, targetStrides),
3243  nonReducedType.getMemorySpace());
3244  }
3245 };
3246 
3247 /// A canonicalizer wrapper to replace SubViewOps.
3248 struct SubViewCanonicalizer {
3249   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3250  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3251  }
3252 };
3253 
3254 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3255  MLIRContext *context) {
3256  results
3257       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3258                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3259            SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3260 }
3261 
3262 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3263  MemRefType sourceMemrefType = getSource().getType();
3264  MemRefType resultMemrefType = getResult().getType();
3265  auto resultLayout =
3266  dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3267 
3268  if (resultMemrefType == sourceMemrefType &&
3269  resultMemrefType.hasStaticShape() &&
3270  (!resultLayout || resultLayout.hasStaticLayout())) {
3271  return getViewSource();
3272  }
3273 
3274  // Fold subview(subview(x)), where both subviews have the same size and the
3275  // second subview's offsets are all zero. (I.e., the second subview is a
3276  // no-op.)
3277  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3278  auto srcSizes = srcSubview.getMixedSizes();
3279  auto sizes = getMixedSizes();
3280  auto offsets = getMixedOffsets();
3281  bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3282  auto strides = getMixedStrides();
3283  bool allStridesOne = llvm::all_of(strides, isOneInteger);
3284  bool allSizesSame = llvm::equal(sizes, srcSizes);
3285  if (allOffsetsZero && allStridesOne && allSizesSame &&
3286  resultMemrefType == sourceMemrefType)
3287  return getViewSource();
3288  }
3289 
3290  return {};
3291 }
3292 
3293 //===----------------------------------------------------------------------===//
3294 // TransposeOp
3295 //===----------------------------------------------------------------------===//
3296 
3297 void TransposeOp::getAsmResultNames(
3298  function_ref<void(Value, StringRef)> setNameFn) {
3299  setNameFn(getResult(), "transpose");
3300 }
3301 
3302 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
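///
/// For illustration (an assumed example, not from the original source):
/// applying the permutation map (d0, d1) -> (d1, d0) to memref<4x8xf32>
/// (sizes [4, 8], strides [8, 1]) yields memref<8x4xf32, strided<[1, 8]>>.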
3303 static MemRefType inferTransposeResultType(MemRefType memRefType,
3304  AffineMap permutationMap) {
3305  auto originalSizes = memRefType.getShape();
3306  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3307  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3308 
3309  // Compute permuted sizes and strides.
3310  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3311  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3312 
3313  return MemRefType::Builder(memRefType)
3314  .setShape(sizes)
3315  .setLayout(
3316  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3317 }
3318 
3319 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3320  AffineMapAttr permutation,
3321  ArrayRef<NamedAttribute> attrs) {
3322  auto permutationMap = permutation.getValue();
3323  assert(permutationMap);
3324 
3325  auto memRefType = llvm::cast<MemRefType>(in.getType());
3326  // Compute result type.
3327  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3328 
3329  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3330  build(b, result, resultType, in, attrs);
3331 }
3332 
3333 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3334 void TransposeOp::print(OpAsmPrinter &p) {
3335   p << " " << getIn() << " " << getPermutation();
3336  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3337  p << " : " << getIn().getType() << " to " << getType();
3338 }
3339 
3340 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3341   OpAsmParser::UnresolvedOperand in;
3342   AffineMap permutation;
3343  MemRefType srcType, dstType;
3344  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3345  parser.parseOptionalAttrDict(result.attributes) ||
3346  parser.parseColonType(srcType) ||
3347  parser.resolveOperand(in, srcType, result.operands) ||
3348  parser.parseKeywordType("to", dstType) ||
3349  parser.addTypeToList(dstType, result.types))
3350  return failure();
3351 
3352  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3353  AffineMapAttr::get(permutation));
3354  return success();
3355 }
3356 
3357 LogicalResult TransposeOp::verify() {
3358  if (!getPermutation().isPermutation())
3359  return emitOpError("expected a permutation map");
3360  if (getPermutation().getNumDims() != getIn().getType().getRank())
3361  return emitOpError("expected a permutation map of same rank as the input");
3362 
3363  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3364  auto resultType = llvm::cast<MemRefType>(getType());
3365  auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3366  .canonicalizeStridedLayout();
3367 
3368  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3369  return emitOpError("result type ")
3370  << resultType
3371  << " is not equivalent to the canonical transposed input type "
3372  << canonicalResultType;
3373  return success();
3374 }
3375 
3376 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3377   // First check for an identity permutation; we can fold it away if the input
3378   // and result types are already identical.
3379  if (getPermutation().isIdentity() && getType() == getIn().getType())
3380  return getIn();
3381  // Fold two consecutive memref.transpose Ops into one by composing their
3382  // permutation maps.
3383  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3384  AffineMap composedPermutation =
3385  getPermutation().compose(otherTransposeOp.getPermutation());
3386  getInMutable().assign(otherTransposeOp.getIn());
3387  setPermutation(composedPermutation);
3388  return getResult();
3389  }
3390  return {};
3391 }
3392 
3393 //===----------------------------------------------------------------------===//
3394 // ViewOp
3395 //===----------------------------------------------------------------------===//
3396 
3397 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3398  setNameFn(getResult(), "view");
3399 }
3400 
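// For illustration (an assumed example, not from this file), a memref.view
// that satisfies the checks below: identity layouts, matching memory spaces,
// and one size operand for the single dynamic dimension.
// ```
// %v = memref.view %buffer[%byte_shift][%size]
//     : memref<2048xi8> to memref<?x4xf32>
// ```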
3401 LogicalResult ViewOp::verify() {
3402  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3403  auto viewType = getType();
3404 
3405  // The base memref should have identity layout map (or none).
3406  if (!baseType.getLayout().isIdentity())
3407  return emitError("unsupported map for base memref type ") << baseType;
3408 
3409  // The result memref should have identity layout map (or none).
3410  if (!viewType.getLayout().isIdentity())
3411  return emitError("unsupported map for result memref type ") << viewType;
3412 
3413  // The base memref and the view memref should be in the same memory space.
3414  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3415  return emitError("different memory spaces specified for base memref "
3416  "type ")
3417  << baseType << " and view memref type " << viewType;
3418 
3419  // Verify that we have the correct number of sizes for the result type.
3420  unsigned numDynamicDims = viewType.getNumDynamicDims();
3421  if (getSizes().size() != numDynamicDims)
3422  return emitError("incorrect number of size operands for type ") << viewType;
3423 
3424  return success();
3425 }
3426 
3427 Value ViewOp::getViewSource() { return getSource(); }
3428 
3429 OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3430  MemRefType sourceMemrefType = getSource().getType();
3431  MemRefType resultMemrefType = getResult().getType();
3432 
3433  if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
3434  return getViewSource();
3435 
3436  return {};
3437 }
3438 
3439 namespace {
3440 
3441 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3442   using OpRewritePattern<ViewOp>::OpRewritePattern;
3443 
3444  LogicalResult matchAndRewrite(ViewOp viewOp,
3445  PatternRewriter &rewriter) const override {
3446  // Return if none of the operands are constants.
3447  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3448  return matchPattern(operand, matchConstantIndex());
3449  }))
3450  return failure();
3451 
3452  // Get result memref type.
3453  auto memrefType = viewOp.getType();
3454 
3455   // Get the offset from the old memref view type 'memrefType'.
3456  int64_t oldOffset;
3457  SmallVector<int64_t, 4> oldStrides;
3458  if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3459  return failure();
3460  assert(oldOffset == 0 && "Expected 0 offset");
3461 
3462  SmallVector<Value, 4> newOperands;
3463 
3464  // Offset cannot be folded into result type.
3465 
3466  // Fold any dynamic dim operands which are produced by a constant.
3467  SmallVector<int64_t, 4> newShapeConstants;
3468  newShapeConstants.reserve(memrefType.getRank());
3469 
3470  unsigned dynamicDimPos = 0;
3471  unsigned rank = memrefType.getRank();
3472  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3473  int64_t dimSize = memrefType.getDimSize(dim);
3474   // If this is already a static dimension, keep it.
3475  if (ShapedType::isStatic(dimSize)) {
3476  newShapeConstants.push_back(dimSize);
3477  continue;
3478  }
3479  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3480  if (auto constantIndexOp =
3481  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3482  // Dynamic shape dimension will be folded.
3483  newShapeConstants.push_back(constantIndexOp.value());
3484  } else {
3485  // Dynamic shape dimension not folded; copy operand from old memref.
3486  newShapeConstants.push_back(dimSize);
3487  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3488  }
3489  dynamicDimPos++;
3490  }
3491 
3492  // Create new memref type with constant folded dims.
3493  MemRefType newMemRefType =
3494  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3495  // Nothing new, don't fold.
3496  if (newMemRefType == memrefType)
3497  return failure();
3498 
3499  // Create new ViewOp.
3500  auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
3501  viewOp.getOperand(0), viewOp.getByteShift(),
3502  newOperands);
3503  // Insert a cast so we have the same type as the old memref type.
3504  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3505  return success();
3506  }
3507 };
3508 
3509 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3510   using OpRewritePattern<ViewOp>::OpRewritePattern;
3511 
3512  LogicalResult matchAndRewrite(ViewOp viewOp,
3513  PatternRewriter &rewriter) const override {
3514  Value memrefOperand = viewOp.getOperand(0);
3515  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3516  if (!memrefCastOp)
3517  return failure();
3518  Value allocOperand = memrefCastOp.getOperand();
3519  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3520  if (!allocOp)
3521  return failure();
3522  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3523  viewOp.getByteShift(),
3524  viewOp.getSizes());
3525  return success();
3526  }
3527 };
3528 
3529 } // namespace
3530 
3531 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3532  MLIRContext *context) {
3533  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3534 }
3535 
3536 //===----------------------------------------------------------------------===//
3537 // AtomicRMWOp
3538 //===----------------------------------------------------------------------===//
3539 
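// For illustration (an assumed example, not from this file), an op accepted
// by the verifier below: two subscripts for a rank-2 memref, and the 'addf'
// kind paired with a floating-point value type.
// ```
// %old = memref.atomic_rmw addf %value, %buf[%i, %j]
//     : (f32, memref<16x16xf32>) -> f32
// ```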
3540 LogicalResult AtomicRMWOp::verify() {
3541  if (getMemRefType().getRank() != getNumOperands() - 2)
3542  return emitOpError(
3543  "expects the number of subscripts to be equal to memref rank");
3544  switch (getKind()) {
3545  case arith::AtomicRMWKind::addf:
3546  case arith::AtomicRMWKind::maximumf:
3547  case arith::AtomicRMWKind::minimumf:
3548  case arith::AtomicRMWKind::mulf:
3549  if (!llvm::isa<FloatType>(getValue().getType()))
3550  return emitOpError() << "with kind '"
3551  << arith::stringifyAtomicRMWKind(getKind())
3552  << "' expects a floating-point type";
3553  break;
3554  case arith::AtomicRMWKind::addi:
3555  case arith::AtomicRMWKind::maxs:
3556  case arith::AtomicRMWKind::maxu:
3557  case arith::AtomicRMWKind::mins:
3558  case arith::AtomicRMWKind::minu:
3559  case arith::AtomicRMWKind::muli:
3560  case arith::AtomicRMWKind::ori:
3561  case arith::AtomicRMWKind::andi:
3562  if (!llvm::isa<IntegerType>(getValue().getType()))
3563  return emitOpError() << "with kind '"
3564  << arith::stringifyAtomicRMWKind(getKind())
3565  << "' expects an integer type";
3566  break;
3567  default:
3568  break;
3569  }
3570  return success();
3571 }
3572 
3573 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3574  /// atomicrmw(memrefcast) -> atomicrmw
3575  if (succeeded(foldMemRefCast(*this, getValue())))
3576  return getResult();
3577  return OpFoldResult();
3578 }
3579 
3580 //===----------------------------------------------------------------------===//
3581 // TableGen'd op method definitions
3582 //===----------------------------------------------------------------------===//
3583 
3584 #define GET_OP_CLASSES
3585 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"