MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Arith/Utils/Utils.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/Utils/StaticValueUtils.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::memref;
29 
30 /// Materialize a single constant operation from a given attribute value with
31 /// the desired resultant type.
32 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
33  Attribute value, Type type,
34  Location loc) {
35  return arith::ConstantOp::materialize(builder, value, type, loc);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Common canonicalization pattern support logic
40 //===----------------------------------------------------------------------===//
41 
42 /// This is a common class used for patterns of the form
43 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
44 /// into the root operation directly.
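///
/// For illustration (a sketch; `someop` stands for an arbitrary consumer op):
/// ```mlir
/// %1 = memref.cast %0 : memref<8xf32> to memref<?xf32>
/// someop(%1) ... : memref<?xf32> ...
/// ```
/// folds to
/// ```mlir
/// someop(%0) ... : memref<8xf32> ...
/// ```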
45 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
46  bool folded = false;
47  for (OpOperand &operand : op->getOpOperands()) {
48  auto cast = operand.get().getDefiningOp<CastOp>();
49  if (cast && operand.get() != inner &&
50  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
51  operand.set(cast.getOperand());
52  folded = true;
53  }
54  }
55  return success(folded);
56 }
57 
58 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
59 /// type.
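/// For example, memref<4x?xf32> maps to tensor<4x?xf32>, memref<*xf32> maps
/// to tensor<*xf32>, and any non-memref type maps to NoneType.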
60 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
61  if (auto memref = llvm::dyn_cast<MemRefType>(type))
62  return RankedTensorType::get(memref.getShape(), memref.getElementType());
63  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
64  return UnrankedTensorType::get(memref.getElementType());
65  return NoneType::get(type.getContext());
66 }
67 
68 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
69  int64_t dim) {
70  auto memrefType = llvm::cast<MemRefType>(value.getType());
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78  Location loc, Value value) {
79  auto memrefType = llvm::cast<MemRefType>(value.getType());
80  SmallVector<OpFoldResult> result;
81  for (int64_t i = 0; i < memrefType.getRank(); ++i)
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result;
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that sets values[i] to constValues[i] if the latter is a
91 /// static value (i.e., not ShapedType::kDynamic).
92 ///
93 /// If constValues[i] is dynamic, tries to extract a constant value from
94 /// values[i] to allow for additional folding opportunities. Also converts all
95 /// existing attributes to index attributes. (They may be i64 attributes.)
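///
/// Illustrative sketch: with values = [%c3, %n, 5 : i64] and
/// constValues = [kDynamic, 7, kDynamic], where %c3 is a constant 3 and %n is
/// not a constant, the result is [3 : index, 7 : index, 5 : index].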
96 static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
97  ArrayRef<int64_t> constValues) {
98  assert(constValues.size() == values.size() &&
99  "incorrect number of const values");
100  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
101  Builder builder(values[i].getContext());
102  if (ShapedType::isStatic(cstVal)) {
103  // Constant value is known, use it directly.
104  values[i] = builder.getIndexAttr(cstVal);
105  continue;
106  }
107  if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
108  // Try to extract a constant or convert an existing to index.
109  values[i] = builder.getIndexAttr(*cst);
110  }
111  }
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // AllocOp / AllocaOp
116 //===----------------------------------------------------------------------===//
117 
118 void AllocOp::getAsmResultNames(
119  function_ref<void(Value, StringRef)> setNameFn) {
120  setNameFn(getResult(), "alloc");
121 }
122 
123 void AllocaOp::getAsmResultNames(
124  function_ref<void(Value, StringRef)> setNameFn) {
125  setNameFn(getResult(), "alloca");
126 }
127 
128 template <typename AllocLikeOp>
129 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
130  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
131  "applies to only alloc or alloca");
132  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
133  if (!memRefType)
134  return op.emitOpError("result must be a memref");
135 
136  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
137  return op.emitOpError("dimension operand count does not equal memref "
138  "dynamic dimension count");
139 
140  unsigned numSymbols = 0;
141  if (!memRefType.getLayout().isIdentity())
142  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
143  if (op.getSymbolOperands().size() != numSymbols)
144  return op.emitOpError("symbol operand count does not equal memref symbol "
145  "count: expected ")
146  << numSymbols << ", got " << op.getSymbolOperands().size();
147 
148  return success();
149 }
150 
151 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
152 
153 LogicalResult AllocaOp::verify() {
154  // An alloca op needs to have an ancestor with an allocation scope trait.
155  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
156  return emitOpError(
157  "requires an ancestor op with AutomaticAllocationScope trait");
158 
159  return verifyAllocLikeOp(*this);
160 }
161 
162 namespace {
163 /// Fold constant dimensions into an alloc like operation.
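///
/// A representative rewrite (illustrative types and names):
/// ```mlir
/// %c8 = arith.constant 8 : index
/// %0 = memref.alloc(%c8, %n) : memref<?x?xf32>
/// ```
/// becomes
/// ```mlir
/// %1 = memref.alloc(%n) : memref<8x?xf32>
/// %0 = memref.cast %1 : memref<8x?xf32> to memref<?x?xf32>
/// ```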
164 template <typename AllocLikeOp>
165 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
166  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
167 
168  LogicalResult matchAndRewrite(AllocLikeOp alloc,
169  PatternRewriter &rewriter) const override {
170  // Check to see if any dimension operands are constants. If so, we can
171  // substitute and drop them.
172  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
173  APInt constSizeArg;
174  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
175  return false;
176  return constSizeArg.isNonNegative();
177  }))
178  return failure();
179 
180  auto memrefType = alloc.getType();
181 
182  // Ok, we have one or more constant operands. Collect the non-constant ones
183  // and keep track of the resultant memref type to build.
184  SmallVector<int64_t, 4> newShapeConstants;
185  newShapeConstants.reserve(memrefType.getRank());
186  SmallVector<Value, 4> dynamicSizes;
187 
188  unsigned dynamicDimPos = 0;
189  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
190  int64_t dimSize = memrefType.getDimSize(dim);
191  // If this is already a static dimension, keep it.
192  if (ShapedType::isStatic(dimSize)) {
193  newShapeConstants.push_back(dimSize);
194  continue;
195  }
196  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
197  APInt constSizeArg;
198  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
199  constSizeArg.isNonNegative()) {
200  // Dynamic shape dimension will be folded.
201  newShapeConstants.push_back(constSizeArg.getZExtValue());
202  } else {
203  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
204  newShapeConstants.push_back(ShapedType::kDynamic);
205  dynamicSizes.push_back(dynamicSize);
206  }
207  dynamicDimPos++;
208  }
209 
210  // Create new memref type (which will have fewer dynamic dimensions).
211  MemRefType newMemRefType =
212  MemRefType::Builder(memrefType).setShape(newShapeConstants);
213  assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
214 
215  // Create and insert the alloc op for the new memref.
216  auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
217  dynamicSizes, alloc.getSymbolOperands(),
218  alloc.getAlignmentAttr());
219  // Insert a cast so we have the same type as the old alloc.
220  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
221  return success();
222  }
223 };
224 
225 /// Fold alloc operations with no users or only store and dealloc uses.
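///
/// For instance, all three ops below are erased, since the alloc is only
/// written to and deallocated (%cst and %i are defined elsewhere):
/// ```mlir
/// %0 = memref.alloc() : memref<4xf32>
/// memref.store %cst, %0[%i] : memref<4xf32>
/// memref.dealloc %0 : memref<4xf32>
/// ```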
226 template <typename T>
227 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
228  using OpRewritePattern<T>::OpRewritePattern;
229 
230  LogicalResult matchAndRewrite(T alloc,
231  PatternRewriter &rewriter) const override {
232  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
233  if (auto storeOp = dyn_cast<StoreOp>(op))
234  return storeOp.getValue() == alloc;
235  return !isa<DeallocOp>(op);
236  }))
237  return failure();
238 
239  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
240  rewriter.eraseOp(user);
241 
242  rewriter.eraseOp(alloc);
243  return success();
244  }
245 };
246 } // namespace
247 
248 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
249  MLIRContext *context) {
250  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
251 }
252 
253 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
254  MLIRContext *context) {
255  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
256  context);
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // ReallocOp
261 //===----------------------------------------------------------------------===//
262 
263 LogicalResult ReallocOp::verify() {
264  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
265  MemRefType resultType = getType();
266 
267  // The source memref should have identity layout (or none).
268  if (!sourceType.getLayout().isIdentity())
269  return emitError("unsupported layout for source memref type ")
270  << sourceType;
271 
272  // The result memref should have identity layout (or none).
273  if (!resultType.getLayout().isIdentity())
274  return emitError("unsupported layout for result memref type ")
275  << resultType;
276 
277  // The source memref and the result memref should be in the same memory space.
278  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
279  return emitError("different memory spaces specified for source memref "
280  "type ")
281  << sourceType << " and result memref type " << resultType;
282 
283  // The source memref and the result memref should have the same element type.
284  if (sourceType.getElementType() != resultType.getElementType())
285  return emitError("different element types specified for source memref "
286  "type ")
287  << sourceType << " and result memref type " << resultType;
288 
289  // Verify that we have the dynamic dimension operand when it is needed.
290  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
291  return emitError("missing dimension operand for result type ")
292  << resultType;
293  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
294  return emitError("unnecessary dimension operand for result type ")
295  << resultType;
296 
297  return success();
298 }
299 
300 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
301  MLIRContext *context) {
302  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // AllocaScopeOp
307 //===----------------------------------------------------------------------===//
308 
309 void AllocaScopeOp::print(OpAsmPrinter &p) {
310  bool printBlockTerminators = false;
311 
312  p << ' ';
313  if (!getResults().empty()) {
314  p << " -> (" << getResultTypes() << ")";
315  printBlockTerminators = true;
316  }
317  p << ' ';
318  p.printRegion(getBodyRegion(),
319  /*printEntryBlockArgs=*/false,
320  /*printBlockTerminators=*/printBlockTerminators);
321  p.printOptionalAttrDict((*this)->getAttrs());
322 }
323 
324 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
325  // Create a region for the body.
326  result.regions.reserve(1);
327  Region *bodyRegion = result.addRegion();
328 
329  // Parse optional results type list.
330  if (parser.parseOptionalArrowTypeList(result.types))
331  return failure();
332 
333  // Parse the body region.
334  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
335  return failure();
336  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
337  result.location);
338 
339  // Parse the optional attribute list.
340  if (parser.parseOptionalAttrDict(result.attributes))
341  return failure();
342 
343  return success();
344 }
345 
346 void AllocaScopeOp::getSuccessorRegions(
347  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
348  if (!point.isParent()) {
349  regions.push_back(RegionSuccessor(getResults()));
350  return;
351  }
352 
353  regions.push_back(RegionSuccessor(&getBodyRegion()));
354 }
355 
356 /// Given an operation, return whether this op is guaranteed to
357 /// allocate an AutomaticAllocationScopeResource
358 static bool isGuaranteedAutomaticAllocation(Operation *op) {
359  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
360  if (!interface)
361  return false;
362  for (auto res : op->getResults()) {
363  if (auto effect =
364  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
365  if (isa<SideEffects::AutomaticAllocationScopeResource>(
366  effect->getResource()))
367  return true;
368  }
369  }
370  return false;
371 }
372 
373 /// Given an operation, return whether this op itself could
374 /// allocate an AutomaticAllocationScopeResource. Note that
375 /// this will not check whether an operation contained within
376 /// the op can allocate.
377 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
378  // This op itself doesn't create a stack allocation,
379  // the inner allocation should be handled separately.
380  if (op->hasTrait<OpTrait::AutomaticAllocationScope>())
381  return false;
382  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
383  if (!interface)
384  return true;
385  for (auto res : op->getResults()) {
386  if (auto effect =
387  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
388  if (isa<SideEffects::AutomaticAllocationScopeResource>(
389  effect->getResource()))
390  return true;
391  }
392  }
393  return false;
394 }
395 
396 /// Return whether this op is the last non-terminating op
397 /// in a region. That is to say, it is in a one-block region
398 /// and is only followed by a terminator. This prevents
399 /// extending the lifetime of allocations.
400 static bool lastNonTerminatorInRegion(Operation *op) {
401  return op->getBlock()->mightHaveTerminator() &&
402  op->getNextNode() == op->getBlock()->getTerminator() &&
403  op->getParentRegion()->hasOneBlock();
404 }
405 
406 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
407 /// or it contains no allocation.
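///
/// Roughly, a scope whose body performs no stack allocation (the inner op
/// below is a placeholder):
/// ```mlir
/// memref.alloca_scope {
///   "some.op"() : () -> ()
/// }
/// ```
/// is replaced by its body:
/// ```mlir
/// "some.op"() : () -> ()
/// ```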
408 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
409  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
410 
411  LogicalResult matchAndRewrite(AllocaScopeOp op,
412  PatternRewriter &rewriter) const override {
413  bool hasPotentialAlloca =
414  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
415  if (alloc == op)
416  return WalkResult::advance();
417  if (isOpItselfPotentialAutomaticAllocation(alloc))
418  return WalkResult::interrupt();
419  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
420  return WalkResult::skip();
421  return WalkResult::advance();
422  }).wasInterrupted();
423 
424  // If this contains no potential allocation, it is always legal to
425  // inline. Otherwise, consider two conditions:
426  if (hasPotentialAlloca) {
427  // If the parent isn't an allocation scope, or we are not the last
428  // non-terminator op in the parent, we will extend the lifetime.
429  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
430  return failure();
431  if (!lastNonTerminatorInRegion(op))
432  return failure();
433  }
434 
435  Block *block = &op.getRegion().front();
436  Operation *terminator = block->getTerminator();
437  ValueRange results = terminator->getOperands();
438  rewriter.inlineBlockBefore(block, op);
439  rewriter.replaceOp(op, results);
440  rewriter.eraseOp(terminator);
441  return success();
442  }
443 };
444 
445 /// Move allocations into an allocation scope, if it is legal to
446 /// move them (e.g. their operands are available at the location
447 /// the op would be moved to).
448 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
449  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
450 
451  LogicalResult matchAndRewrite(AllocaScopeOp op,
452  PatternRewriter &rewriter) const override {
453 
454  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
455  return failure();
456 
457  Operation *lastParentWithoutScope = op->getParentOp();
458 
459  if (!lastParentWithoutScope ||
460  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
461  return failure();
462 
463  // Only apply if this is the last non-terminator op
464  // in the block (lest the lifetime be extended) of a
465  // one-block region.
466  if (!lastNonTerminatorInRegion(op) ||
467  !lastNonTerminatorInRegion(lastParentWithoutScope))
468  return failure();
469 
470  while (!lastParentWithoutScope->getParentOp()
471  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
472  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
473  if (!lastParentWithoutScope ||
474  !lastNonTerminatorInRegion(lastParentWithoutScope))
475  return failure();
476  }
477  assert(lastParentWithoutScope->getParentOp()
478  ->hasTrait<OpTrait::AutomaticAllocationScope>());
479 
480  Region *containingRegion = nullptr;
481  for (auto &r : lastParentWithoutScope->getRegions()) {
482  if (r.isAncestor(op->getParentRegion())) {
483  assert(containingRegion == nullptr &&
484  "only one region can contain the op");
485  containingRegion = &r;
486  }
487  }
488  assert(containingRegion && "op must be contained in a region");
489 
490  SmallVector<Operation *> toHoist;
491  op->walk([&](Operation *alloc) {
492  if (!isOpItselfPotentialAutomaticAllocation(alloc))
493  return WalkResult::skip();
494 
495  // If any operand is not defined before the location of
496  // lastParentWithoutScope (i.e. where we would hoist to), skip.
497  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
498  return containingRegion->isAncestor(v.getParentRegion());
499  }))
500  return WalkResult::skip();
501  toHoist.push_back(alloc);
502  return WalkResult::advance();
503  });
504 
505  if (toHoist.empty())
506  return failure();
507  rewriter.setInsertionPoint(lastParentWithoutScope);
508  for (auto *op : toHoist) {
509  auto *cloned = rewriter.clone(*op);
510  rewriter.replaceOp(op, cloned->getResults());
511  }
512  return success();
513  }
514 };
515 
516 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
517  MLIRContext *context) {
518  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // AssumeAlignmentOp
523 //===----------------------------------------------------------------------===//
524 
525 LogicalResult AssumeAlignmentOp::verify() {
526  if (!llvm::isPowerOf2_32(getAlignment()))
527  return emitOpError("alignment must be power of 2");
528  return success();
529 }
530 
531 void AssumeAlignmentOp::getAsmResultNames(
532  function_ref<void(Value, StringRef)> setNameFn) {
533  setNameFn(getResult(), "assume_align");
534 }
535 
536 OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
537  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
538  if (!source)
539  return {};
540  if (source.getAlignment() != getAlignment())
541  return {};
542  return getMemref();
543 }
544 
545 //===----------------------------------------------------------------------===//
546 // CastOp
547 //===----------------------------------------------------------------------===//
548 
549 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
550  setNameFn(getResult(), "cast");
551 }
552 
553 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
554 /// source memref. This is useful to fold a memref.cast into a consuming op
555 /// and implement canonicalization patterns for ops in different dialects that
556 /// may consume the results of memref.cast operations. Such foldable memref.cast
557 /// operations are typically inserted as `view` and `subview` ops are
558 /// canonicalized, to preserve the type compatibility of their uses.
559 ///
560 /// Returns true when all conditions are met:
561 /// 1. source and result are ranked memrefs with strided semantics and same
562 /// element type and rank.
563 /// 2. each of the source's size, offset or stride has more static information
564 /// than the corresponding result's size, offset or stride.
565 ///
566 /// Example 1:
567 /// ```mlir
568 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
569 /// %2 = consumer %1 ... : memref<?x?xf32> ...
570 /// ```
571 ///
572 /// may fold into:
573 ///
574 /// ```mlir
575 /// %2 = consumer %0 ... : memref<8x16xf32> ...
576 /// ```
577 ///
578 /// Example 2:
579 /// ```
580 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
581 /// to memref<?x?xf32>
582 /// consumer %1 : memref<?x?xf32> ...
583 /// ```
584 ///
585 /// may fold into:
586 ///
587 /// ```
588 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
589 /// ```
590 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
591  MemRefType sourceType =
592  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
593  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
594 
595  // Requires ranked MemRefType.
596  if (!sourceType || !resultType)
597  return false;
598 
599  // Requires same elemental type.
600  if (sourceType.getElementType() != resultType.getElementType())
601  return false;
602 
603  // Requires same rank.
604  if (sourceType.getRank() != resultType.getRank())
605  return false;
606 
607  // Only fold casts between strided memref forms.
608  int64_t sourceOffset, resultOffset;
609  SmallVector<int64_t, 4> sourceStrides, resultStrides;
610  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
611  failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
612  return false;
613 
614  // If cast is towards more static sizes along any dimension, don't fold.
615  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
616  auto ss = std::get<0>(it), st = std::get<1>(it);
617  if (ss != st)
618  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
619  return false;
620  }
621 
622  // If cast is towards more static offset along any dimension, don't fold.
623  if (sourceOffset != resultOffset)
624  if (ShapedType::isDynamic(sourceOffset) &&
625  ShapedType::isStatic(resultOffset))
626  return false;
627 
628  // If cast is towards more static strides along any dimension, don't fold.
629  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
630  auto ss = std::get<0>(it), st = std::get<1>(it);
631  if (ss != st)
632  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
633  return false;
634  }
635 
636  return true;
637 }
638 
639 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
640  if (inputs.size() != 1 || outputs.size() != 1)
641  return false;
642  Type a = inputs.front(), b = outputs.front();
643  auto aT = llvm::dyn_cast<MemRefType>(a);
644  auto bT = llvm::dyn_cast<MemRefType>(b);
645 
646  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
647  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
648 
649  if (aT && bT) {
650  if (aT.getElementType() != bT.getElementType())
651  return false;
652  if (aT.getLayout() != bT.getLayout()) {
653  int64_t aOffset, bOffset;
654  SmallVector<int64_t, 4> aStrides, bStrides;
655  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
656  failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
657  aStrides.size() != bStrides.size())
658  return false;
659 
660  // Strides along a dimension/offset are compatible if the value in the
661  // source memref is static and the value in the target memref is the
662  // same. They are also compatible if either one is dynamic (see
663  // description of MemRefCastOp for details).
664  auto checkCompatible = [](int64_t a, int64_t b) {
665  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
666  };
667  if (!checkCompatible(aOffset, bOffset))
668  return false;
669  for (const auto &aStride : enumerate(aStrides))
670  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
671  return false;
672  }
673  if (aT.getMemorySpace() != bT.getMemorySpace())
674  return false;
675 
676  // They must have the same rank, and any specified dimensions must match.
677  if (aT.getRank() != bT.getRank())
678  return false;
679 
680  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
681  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
682  if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
683  aDim != bDim)
684  return false;
685  }
686  return true;
687  } else {
688  if (!aT && !uaT)
689  return false;
690  if (!bT && !ubT)
691  return false;
692  // Unranked to unranked casting is unsupported
693  if (uaT && ubT)
694  return false;
695 
696  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
697  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
698  if (aEltType != bEltType)
699  return false;
700 
701  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
702  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
703  return aMemSpace == bMemSpace;
704  }
705 
706  return false;
707 }
708 
709 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
710  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // CopyOp
715 //===----------------------------------------------------------------------===//
716 
717 namespace {
718 
719 /// Fold memref.copy(%x, %x).
720 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
721  using OpRewritePattern<CopyOp>::OpRewritePattern;
722 
723  LogicalResult matchAndRewrite(CopyOp copyOp,
724  PatternRewriter &rewriter) const override {
725  if (copyOp.getSource() != copyOp.getTarget())
726  return failure();
727 
728  rewriter.eraseOp(copyOp);
729  return success();
730  }
731 };
732 
733 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
734  using OpRewritePattern<CopyOp>::OpRewritePattern;
735 
736  static bool isEmptyMemRef(BaseMemRefType type) {
737  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
738  }
739 
740  LogicalResult matchAndRewrite(CopyOp copyOp,
741  PatternRewriter &rewriter) const override {
742  if (isEmptyMemRef(copyOp.getSource().getType()) ||
743  isEmptyMemRef(copyOp.getTarget().getType())) {
744  rewriter.eraseOp(copyOp);
745  return success();
746  }
747 
748  return failure();
749  }
750 };
751 } // namespace
752 
753 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
754  MLIRContext *context) {
755  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
756 }
757 
758 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
759 /// and element type, the cast can be skipped. Such CastOps only cast the layout
760 /// of the type.
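///
/// For example (illustrative types):
/// ```mlir
/// %0 = memref.cast %src : memref<8xf32> to memref<8xf32, strided<[?], offset: ?>>
/// memref.copy %0, %dst : memref<8xf32, strided<[?], offset: ?>> to memref<8xf32>
/// ```
/// folds to
/// ```mlir
/// memref.copy %src, %dst : memref<8xf32> to memref<8xf32>
/// ```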
761 static LogicalResult FoldCopyOfCast(CopyOp op) {
762  for (OpOperand &operand : op->getOpOperands()) {
763  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
764  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
765  operand.set(castOp.getOperand());
766  return success();
767  }
768  }
769  return failure();
770 }
771 
772 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
773  SmallVectorImpl<OpFoldResult> &results) {
774 
775  /// copy(memrefcast) -> copy
776  return FoldCopyOfCast(*this);
777 }
778 
779 //===----------------------------------------------------------------------===//
780 // DeallocOp
781 //===----------------------------------------------------------------------===//
782 
783 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
784  SmallVectorImpl<OpFoldResult> &results) {
785  /// dealloc(memrefcast) -> dealloc
786  return foldMemRefCast(*this);
787 }
788 
789 //===----------------------------------------------------------------------===//
790 // DimOp
791 //===----------------------------------------------------------------------===//
792 
793 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
794  setNameFn(getResult(), "dim");
795 }
796 
797 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
798  int64_t index) {
799  auto loc = result.location;
800  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
801  build(builder, result, source, indexValue);
802 }
803 
804 std::optional<int64_t> DimOp::getConstantIndex() {
805  return getConstantIntValue(getIndex());
806 }
807 
808 Speculation::Speculatability DimOp::getSpeculatability() {
809  auto constantIndex = getConstantIndex();
810  if (!constantIndex)
811  return Speculation::NotSpeculatable;
812 
813  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
814  if (!rankedSourceType)
815  return Speculation::NotSpeculatable;
816 
817  if (rankedSourceType.getRank() <= constantIndex)
818  return Speculation::NotSpeculatable;
819 
820  return Speculation::Speculatable;
821 }
822 
823 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
824  SetIntLatticeFn setResultRange) {
825  setResultRange(getResult(),
826  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
827 }
828 
829 /// Return a map with key being elements in `vals` and data being number of
830 /// occurrences of it. Use std::map, since the `vals` here are strides and the
831 /// dynamic stride value is the same as the tombstone value for
832 /// `DenseMap<int64_t>`.
833 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
834  std::map<int64_t, unsigned> numOccurences;
835  for (auto val : vals)
836  numOccurences[val]++;
837  return numOccurences;
838 }
839 
840 /// Given the `originalType` and a `reducedType` whose shape is assumed
841 /// to be a subset of `originalType` with some `1` entries erased, return the
842 /// set of indices that specifies which of the entries of the original shape
843 /// are dropped to obtain the reduced shape.
844 /// This accounts for cases where there are multiple unit-dims, but only a
845 /// subset of those are dropped. For MemRefTypes these can be disambiguated
846 /// using the strides. If a dimension is dropped the stride must be dropped too.
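///
/// As a sketch, reducing memref<1x4x1x8xf32> to memref<4x8xf32> with sizes
/// [1, 4, 1, 8] yields the mask {0, 2} (both unit dims dropped); reducing it
/// to memref<1x4x8xf32> keeps one unit dim, and the strides determine which
/// of the two was actually dropped.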
847 static FailureOr<llvm::SmallBitVector>
848 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
849  ArrayRef<OpFoldResult> sizes) {
850  llvm::SmallBitVector unusedDims(originalType.getRank());
851  if (originalType.getRank() == reducedType.getRank())
852  return unusedDims;
853 
854  for (const auto &dim : llvm::enumerate(sizes))
855  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
856  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
857  unusedDims.set(dim.index());
858 
859  // Early exit for the case where the number of unused dims matches the number
860  // of ranks reduced.
861  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
862  originalType.getRank())
863  return unusedDims;
864 
865  SmallVector<int64_t> originalStrides, candidateStrides;
866  int64_t originalOffset, candidateOffset;
867  if (failed(
868  originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
869  failed(
870  reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
871  return failure();
872 
873  // For memrefs, a dimension is truly dropped if its corresponding stride is
874  // also dropped. This is particularly important when more than one of the dims
875  // is 1. Track the number of occurrences of the strides in the original type
876  // and the candidate type. For each unused dim that stride should not be
877  // present in the candidate type. Note that there could be multiple dimensions
878  // that have the same size. We don't need to exactly figure out which dim
879  // corresponds to which stride; we just need to verify that the number of
880  // repetitions of a stride in the original + number of unused dims with that
881  // stride == number of repetitions of a stride in the candidate.
882  std::map<int64_t, unsigned> currUnaccountedStrides =
883  getNumOccurences(originalStrides);
884  std::map<int64_t, unsigned> candidateStridesNumOccurences =
885  getNumOccurences(candidateStrides);
886  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
887  if (!unusedDims.test(dim))
888  continue;
889  int64_t originalStride = originalStrides[dim];
890  if (currUnaccountedStrides[originalStride] >
891  candidateStridesNumOccurences[originalStride]) {
892  // This dim can be treated as dropped.
893  currUnaccountedStrides[originalStride]--;
894  continue;
895  }
896  if (currUnaccountedStrides[originalStride] ==
897  candidateStridesNumOccurences[originalStride]) {
898  // The stride for this is not dropped. Keep as is.
899  unusedDims.reset(dim);
900  continue;
901  }
902  if (currUnaccountedStrides[originalStride] <
903  candidateStridesNumOccurences[originalStride]) {
904  // This should never happen. Can't have a stride in the reduced rank type
905  // that wasn't in the original one.
906  return failure();
907  }
908  }
909 
910  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
911  originalType.getRank())
912  return failure();
913  return unusedDims;
914 }
915 
916 llvm::SmallBitVector SubViewOp::getDroppedDims() {
917  MemRefType sourceType = getSourceType();
918  MemRefType resultType = getType();
919  FailureOr<llvm::SmallBitVector> unusedDims =
920  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
921  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
922  return *unusedDims;
923 }
924 
925 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
926  // All forms of folding require a known index.
927  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
928  if (!index)
929  return {};
930 
931  // Folding for unranked types (UnrankedMemRefType) is not supported.
932  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
933  if (!memrefType)
934  return {};
935 
936  // Out of bound indices produce undefined behavior but are still valid IR.
937  // Don't choke on them.
938  int64_t indexVal = index.getInt();
939  if (indexVal < 0 || indexVal >= memrefType.getRank())
940  return {};
941 
942  // Fold if the shape extent along the given index is known.
943  if (!memrefType.isDynamicDim(index.getInt())) {
944  Builder builder(getContext());
945  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
946  }
947 
948  // The size at the given index is now known to be a dynamic size.
949  unsigned unsignedIndex = index.getValue().getZExtValue();
950 
951  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
952  Operation *definingOp = getSource().getDefiningOp();
953 
954  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
955  return *(alloc.getDynamicSizes().begin() +
956  memrefType.getDynamicDimIndex(unsignedIndex));
957 
958  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
959  return *(alloca.getDynamicSizes().begin() +
960  memrefType.getDynamicDimIndex(unsignedIndex));
961 
962  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
963  return *(view.getDynamicSizes().begin() +
964  memrefType.getDynamicDimIndex(unsignedIndex));
965 
966  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
967  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
968  unsigned resultIndex = 0;
969  unsigned sourceRank = subview.getSourceType().getRank();
970  unsigned sourceIndex = 0;
971  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
972  if (unusedDims.test(i))
973  continue;
974  if (resultIndex == unsignedIndex) {
975  sourceIndex = i;
976  break;
977  }
978  resultIndex++;
979  }
980  assert(subview.isDynamicSize(sourceIndex) &&
981  "expected dynamic subview size");
982  return subview.getDynamicSize(sourceIndex);
983  }
984 
985  if (auto sizeInterface =
986  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
987  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
988  "Expected dynamic subview size");
989  return sizeInterface.getDynamicSize(unsignedIndex);
990  }
991 
992  // dim(memrefcast) -> dim
993  if (succeeded(foldMemRefCast(*this)))
994  return getResult();
995 
996  return {};
997 }
998 
999 namespace {
1000 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1001 /// operand.
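///
/// For example (illustrative types), assuming %idx dominates the reshape:
/// ```mlir
/// %r = memref.reshape %src(%shape)
///     : (memref<*xf32>, memref<3xindex>) -> memref<?x?x?xf32>
/// %d = memref.dim %r, %idx : memref<?x?x?xf32>
/// ```
/// is rewritten to
/// ```mlir
/// %d = memref.load %shape[%idx] : memref<3xindex>
/// ```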
1002 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1003  using OpRewritePattern<DimOp>::OpRewritePattern;
1004 
1005  LogicalResult matchAndRewrite(DimOp dim,
1006  PatternRewriter &rewriter) const override {
1007  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1008 
1009  if (!reshape)
1010  return rewriter.notifyMatchFailure(
1011  dim, "Dim op is not defined by a reshape op.");
1012 
1013  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1014  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1015  // cheaply check that either of the following conditions hold:
1016  // 1. dim.getIndex() is defined in the same block as reshape but before
1017  // reshape.
1018  // 2. dim.getIndex() is defined in a parent block of
1019  // reshape.
1020 
1021  // Check condition 1
1022  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1023  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1024  if (reshape->isBeforeInBlock(definingOp)) {
1025  return rewriter.notifyMatchFailure(
1026  dim,
1027  "dim.getIndex is not defined before reshape in the same block.");
1028  }
1029  } // else dim.getIndex is a block argument to reshape->getBlock and
1030  // dominates reshape
1031  } // Check condition 2
1032  else if (dim->getBlock() != reshape->getBlock() &&
1033  !dim.getIndex().getParentRegion()->isProperAncestor(
1034  reshape->getParentRegion())) {
1035  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1036  // already know dim.getIndex() dominates reshape without calling
1037  // `isProperAncestor`
1038  return rewriter.notifyMatchFailure(
1039  dim, "dim.getIndex does not dominate reshape.");
1040  }
1041 
1042  // Place the load directly after the reshape to ensure that the shape memref
1043  // was not mutated.
1044  rewriter.setInsertionPointAfter(reshape);
1045  Location loc = dim.getLoc();
1046  Value load =
1047  LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1048  if (load.getType() != dim.getType())
1049  load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
1050  rewriter.replaceOp(dim, load);
1051  return success();
1052  }
1053 };
1054 
1055 } // namespace
1056 
1057 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1058  MLIRContext *context) {
1059  results.add<DimOfMemRefReshape>(context);
1060 }
1061 
1062 // ---------------------------------------------------------------------------
1063 // DmaStartOp
1064 // ---------------------------------------------------------------------------
1065 
1066 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1067  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1068  ValueRange destIndices, Value numElements,
1069  Value tagMemRef, ValueRange tagIndices, Value stride,
1070  Value elementsPerStride) {
1071  result.addOperands(srcMemRef);
1072  result.addOperands(srcIndices);
1073  result.addOperands(destMemRef);
1074  result.addOperands(destIndices);
1075  result.addOperands({numElements, tagMemRef});
1076  result.addOperands(tagIndices);
1077  if (stride)
1078  result.addOperands({stride, elementsPerStride});
1079 }
1080 
1081 void DmaStartOp::print(OpAsmPrinter &p) {
1082  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1083  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1084  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1085  if (isStrided())
1086  p << ", " << getStride() << ", " << getNumElementsPerStride();
1087 
1088  p.printOptionalAttrDict((*this)->getAttrs());
1089  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1090  << ", " << getTagMemRef().getType();
1091 }
1092 
1093 // Parse DmaStartOp.
1094 // Ex:
1095 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1096 // %tag[%index], %stride, %num_elt_per_stride :
1097 // : memref<3076 x f32, 0>,
1098 // memref<1024 x f32, 2>,
1099 // memref<1 x i32>
1100 //
1101 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1102  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1103  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1104  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1105  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1106  OpAsmParser::UnresolvedOperand numElementsInfo;
1107  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1108  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1109  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1110 
1111  SmallVector<Type, 3> types;
1112  auto indexType = parser.getBuilder().getIndexType();
1113 
1114  // Parse and resolve the following list of operands:
1115  // *) source memref followed by its indices (in square brackets).
1116  // *) destination memref followed by its indices (in square brackets).
1117  // *) dma size in KiB.
1118  if (parser.parseOperand(srcMemRefInfo) ||
1119  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1120  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1121  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1122  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1123  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1124  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1125  return failure();
1126 
1127  // Parse optional stride and elements per stride.
1128  if (parser.parseTrailingOperandList(strideInfo))
1129  return failure();
1130 
1131  bool isStrided = strideInfo.size() == 2;
1132  if (!strideInfo.empty() && !isStrided) {
1133  return parser.emitError(parser.getNameLoc(),
1134  "expected two stride related operands");
1135  }
1136 
1137  if (parser.parseColonTypeList(types))
1138  return failure();
1139  if (types.size() != 3)
1140  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1141 
1142  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1143  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1144  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1145  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1146  // size should be an index.
1147  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1148  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1149  // tag indices should be index.
1150  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1151  return failure();
1152 
1153  if (isStrided) {
1154  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1155  return failure();
1156  }
1157 
1158  return success();
1159 }
1160 
1161 LogicalResult DmaStartOp::verify() {
1162  unsigned numOperands = getNumOperands();
1163 
1164  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1165  // the number of elements.
1166  if (numOperands < 4)
1167  return emitOpError("expected at least 4 operands");
1168 
1169  // Check types of operands. The order of these calls is important: the later
1170  // calls rely on some type properties to compute the operand position.
1171  // 1. Source memref.
1172  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1173  return emitOpError("expected source to be of memref type");
1174  if (numOperands < getSrcMemRefRank() + 4)
1175  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1176  << " operands";
1177  if (!getSrcIndices().empty() &&
1178  !llvm::all_of(getSrcIndices().getTypes(),
1179  [](Type t) { return t.isIndex(); }))
1180  return emitOpError("expected source indices to be of index type");
1181 
1182  // 2. Destination memref.
1183  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1184  return emitOpError("expected destination to be of memref type");
1185  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1186  if (numOperands < numExpectedOperands)
1187  return emitOpError() << "expected at least " << numExpectedOperands
1188  << " operands";
1189  if (!getDstIndices().empty() &&
1190  !llvm::all_of(getDstIndices().getTypes(),
1191  [](Type t) { return t.isIndex(); }))
1192  return emitOpError("expected destination indices to be of index type");
1193 
1194  // 3. Number of elements.
1195  if (!getNumElements().getType().isIndex())
1196  return emitOpError("expected num elements to be of index type");
1197 
1198  // 4. Tag memref.
1199  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1200  return emitOpError("expected tag to be of memref type");
1201  numExpectedOperands += getTagMemRefRank();
1202  if (numOperands < numExpectedOperands)
1203  return emitOpError() << "expected at least " << numExpectedOperands
1204  << " operands";
1205  if (!getTagIndices().empty() &&
1206  !llvm::all_of(getTagIndices().getTypes(),
1207  [](Type t) { return t.isIndex(); }))
1208  return emitOpError("expected tag indices to be of index type");
1209 
1210  // Optional stride-related operands must be either both present or both
1211  // absent.
1212  if (numOperands != numExpectedOperands &&
1213  numOperands != numExpectedOperands + 2)
1214  return emitOpError("incorrect number of operands");
1215 
1216  // 5. Strides.
1217  if (isStrided()) {
1218  if (!getStride().getType().isIndex() ||
1219  !getNumElementsPerStride().getType().isIndex())
1220  return emitOpError(
1221  "expected stride and num elements per stride to be of type index");
1222  }
1223 
1224  return success();
1225 }
1226 
1227 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1228  SmallVectorImpl<OpFoldResult> &results) {
1229  /// dma_start(memrefcast) -> dma_start
1230  return foldMemRefCast(*this);
1231 }
1232 
1233 // ---------------------------------------------------------------------------
1234 // DmaWaitOp
1235 // ---------------------------------------------------------------------------
1236 
1237 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1238  SmallVectorImpl<OpFoldResult> &results) {
1239  /// dma_wait(memrefcast) -> dma_wait
1240  return foldMemRefCast(*this);
1241 }
1242 
1243 LogicalResult DmaWaitOp::verify() {
1244  // Check that the number of tag indices matches the tagMemRef rank.
1245  unsigned numTagIndices = getTagIndices().size();
1246  unsigned tagMemRefRank = getTagMemRefRank();
1247  if (numTagIndices != tagMemRefRank)
1248  return emitOpError() << "expected tagIndices to have the same number of "
1249  "elements as the tagMemRef rank, expected "
1250  << tagMemRefRank << ", but got " << numTagIndices;
1251  return success();
1252 }
1253 
1254 //===----------------------------------------------------------------------===//
1255 // ExtractAlignedPointerAsIndexOp
1256 //===----------------------------------------------------------------------===//
1257 
1258 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1259  function_ref<void(Value, StringRef)> setNameFn) {
1260  setNameFn(getResult(), "intptr");
1261 }
1262 
1263 //===----------------------------------------------------------------------===//
1264 // ExtractStridedMetadataOp
1265 //===----------------------------------------------------------------------===//
1266 
1267 /// The number and type of the results are inferred from the
1268 /// shape of the source.
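///
/// For example, a source of type memref<4x?xf32> infers five trailing index
/// results (one offset, two sizes, two strides) in addition to the rank-0
/// memref<f32> base buffer.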
1269 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1270  MLIRContext *context, std::optional<Location> location,
1271  ExtractStridedMetadataOp::Adaptor adaptor,
1272  SmallVectorImpl<Type> &inferredReturnTypes) {
1273  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1274  if (!sourceType)
1275  return failure();
1276 
1277  unsigned sourceRank = sourceType.getRank();
1278  IndexType indexType = IndexType::get(context);
1279  auto memrefType =
1280  MemRefType::get({}, sourceType.getElementType(),
1281  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1282  // Base.
1283  inferredReturnTypes.push_back(memrefType);
1284  // Offset.
1285  inferredReturnTypes.push_back(indexType);
1286  // Sizes and strides.
1287  for (unsigned i = 0; i < sourceRank * 2; ++i)
1288  inferredReturnTypes.push_back(indexType);
1289  return success();
1290 }
1291 
1292 void ExtractStridedMetadataOp::getAsmResultNames(
1293  function_ref<void(Value, StringRef)> setNameFn) {
1294  setNameFn(getBaseBuffer(), "base_buffer");
1295  setNameFn(getOffset(), "offset");
1296  // For multi-result to work properly with pretty names and packed syntax `x:3`
1297  // we can only give a pretty name to the first value in the pack.
1298  if (!getSizes().empty()) {
1299  setNameFn(getSizes().front(), "sizes");
1300  setNameFn(getStrides().front(), "strides");
1301  }
1302 }
1303 
1304 /// Helper function to perform the replacement of all constant uses of `values`
1305 /// by a materialized constant extracted from `maybeConstants`.
1306 /// `values` and `maybeConstants` are expected to have the same size.
1307 template <typename Container>
1308 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1309  Container values,
1310  ArrayRef<OpFoldResult> maybeConstants) {
1311  assert(values.size() == maybeConstants.size() &&
1312  " expected values and maybeConstants of the same size");
1313  bool atLeastOneReplacement = false;
1314  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1315  // Don't materialize a constant if there are no uses: this would induce
1316  // infinite loops in the driver.
1317  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1318  continue;
1319  assert(isa<Attribute>(maybeConstant) &&
1320  "The constified value should be either unchanged (i.e., == result) "
1321  "or a constant");
1322  Value constantVal = arith::ConstantIndexOp::create(
1323  rewriter, loc,
1324  llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1325  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1326  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1327  // yet.
1328  op->replaceUsesOfWith(result, constantVal);
1329  atLeastOneReplacement = true;
1330  }
1331  }
1332  return atLeastOneReplacement;
1333 }
1334 
1335 LogicalResult
1336 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1337  SmallVectorImpl<OpFoldResult> &results) {
1338  OpBuilder builder(*this);
1339 
1340  bool atLeastOneReplacement = replaceConstantUsesOf(
1341  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1342  getConstifiedMixedOffset());
1343  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1344  getConstifiedMixedSizes());
1345  atLeastOneReplacement |= replaceConstantUsesOf(
1346  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1347 
1348  return success(atLeastOneReplacement);
1349 }
1350 
1351 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1352  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1353  constifyIndexValues(values, getSource().getType().getShape());
1354  return values;
1355 }
1356 
1357 SmallVector<OpFoldResult>
1358 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1359  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1360  SmallVector<int64_t> staticValues;
1361  int64_t unused;
1362  LogicalResult status =
1363  getSource().getType().getStridesAndOffset(staticValues, unused);
1364  (void)status;
1365  assert(succeeded(status) && "could not get strides from type");
1366  constifyIndexValues(values, staticValues);
1367  return values;
1368 }
1369 
1370 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1371  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1372  SmallVector<OpFoldResult> values(1, offsetOfr);
1373  SmallVector<int64_t> staticValues, unused;
1374  int64_t offset;
1375  LogicalResult status =
1376  getSource().getType().getStridesAndOffset(unused, offset);
1377  (void)status;
1378  assert(succeeded(status) && "could not get offset from type");
1379  staticValues.push_back(offset);
1380  constifyIndexValues(values, staticValues);
1381  return values[0];
1382 }
1383 
1384 //===----------------------------------------------------------------------===//
1385 // GenericAtomicRMWOp
1386 //===----------------------------------------------------------------------===//
1387 
1388 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1389  Value memref, ValueRange ivs) {
1390  OpBuilder::InsertionGuard g(builder);
1391  result.addOperands(memref);
1392  result.addOperands(ivs);
1393 
1394  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1395  Type elementType = memrefType.getElementType();
1396  result.addTypes(elementType);
1397 
1398  Region *bodyRegion = result.addRegion();
1399  builder.createBlock(bodyRegion);
1400  bodyRegion->addArgument(elementType, memref.getLoc());
1401  }
1402 }
1403 
1404 LogicalResult GenericAtomicRMWOp::verify() {
1405  auto &body = getRegion();
1406  if (body.getNumArguments() != 1)
1407  return emitOpError("expected single number of entry block arguments");
1408 
1409  if (getResult().getType() != body.getArgument(0).getType())
1410  return emitOpError("expected block argument of the same type result type");
1411 
1412  bool hasSideEffects =
1413  body.walk([&](Operation *nestedOp) {
1414  if (isMemoryEffectFree(nestedOp))
1415  return WalkResult::advance();
1416  nestedOp->emitError(
1417  "body of 'memref.generic_atomic_rmw' should contain "
1418  "only operations with no side effects");
1419  return WalkResult::interrupt();
1420  })
1421  .wasInterrupted();
1422  return hasSideEffects ? failure() : success();
1423 }
1424 
1425 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1426  OperationState &result) {
1427  OpAsmParser::UnresolvedOperand memref;
1428  Type memrefType;
1429  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1430 
1431  Type indexType = parser.getBuilder().getIndexType();
1432  if (parser.parseOperand(memref) ||
1433  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1434  parser.parseColonType(memrefType) ||
1435  parser.resolveOperand(memref, memrefType, result.operands) ||
1436  parser.resolveOperands(ivs, indexType, result.operands))
1437  return failure();
1438 
1439  Region *body = result.addRegion();
1440  if (parser.parseRegion(*body, {}) ||
1441  parser.parseOptionalAttrDict(result.attributes))
1442  return failure();
1443  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1444  return success();
1445 }
1446 
1447 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1448  p << ' ' << getMemref() << "[" << getIndices()
1449  << "] : " << getMemref().getType() << ' ';
1450  p.printRegion(getRegion());
1451  p.printOptionalAttrDict((*this)->getAttrs());
1452 }
1453 
1454 //===----------------------------------------------------------------------===//
1455 // AtomicYieldOp
1456 //===----------------------------------------------------------------------===//
1457 
1458 LogicalResult AtomicYieldOp::verify() {
1459  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1460  Type resultType = getResult().getType();
1461  if (parentType != resultType)
1462  return emitOpError() << "types mismatch between yield op: " << resultType
1463  << " and its parent: " << parentType;
1464  return success();
1465 }
1466 
1467 //===----------------------------------------------------------------------===//
1468 // GlobalOp
1469 //===----------------------------------------------------------------------===//
1470 
1471 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1472  TypeAttr type,
1473  Attribute initialValue) {
1474  p << type;
1475  if (!op.isExternal()) {
1476  p << " = ";
1477  if (op.isUninitialized())
1478  p << "uninitialized";
1479  else
1480  p.printAttributeWithoutType(initialValue);
1481  }
1482 }
1483 
1484 static ParseResult
1485 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1486  Attribute &initialValue) {
1487  Type type;
1488  if (parser.parseType(type))
1489  return failure();
1490 
1491  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1492  if (!memrefType || !memrefType.hasStaticShape())
1493  return parser.emitError(parser.getNameLoc())
1494  << "type should be static shaped memref, but got " << type;
1495  typeAttr = TypeAttr::get(type);
1496 
1497  if (parser.parseOptionalEqual())
1498  return success();
1499 
1500  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1501  initialValue = UnitAttr::get(parser.getContext());
1502  return success();
1503  }
1504 
1505  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1506  if (parser.parseAttribute(initialValue, tensorType))
1507  return failure();
1508  if (!llvm::isa<ElementsAttr>(initialValue))
1509  return parser.emitError(parser.getNameLoc())
1510  << "initial value should be a unit or elements attribute";
1511  return success();
1512 }
1513 
1514 LogicalResult GlobalOp::verify() {
1515  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1516  if (!memrefType || !memrefType.hasStaticShape())
1517  return emitOpError("type should be static shaped memref, but got ")
1518  << getType();
1519 
1520  // Verify that the initial value, if present, is either a unit attribute or
1521  // an elements attribute.
1522  if (getInitialValue().has_value()) {
1523  Attribute initValue = getInitialValue().value();
1524  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1525  return emitOpError("initial value should be a unit or elements "
1526  "attribute, but got ")
1527  << initValue;
1528 
1529  // Check that the type of the initial value is compatible with the type of
1530  // the global variable.
1531  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1532  // Check the element types match.
1533  auto initElementType =
1534  cast<TensorType>(elementsAttr.getType()).getElementType();
1535  auto memrefElementType = memrefType.getElementType();
1536 
1537  if (initElementType != memrefElementType)
1538  return emitOpError("initial value element expected to be of type ")
1539  << memrefElementType << ", but was of type " << initElementType;
1540 
1541  // Check the shapes match, given that memref globals can only produce
1542  // statically shaped memrefs and elements literal type must have a static
1543  // shape we can assume both types are shaped.
1544  auto initShape = elementsAttr.getShapedType().getShape();
1545  auto memrefShape = memrefType.getShape();
1546  if (initShape != memrefShape)
1547  return emitOpError("initial value shape expected to be ")
1548  << memrefShape << " but was " << initShape;
1549  }
1550  }
1551 
1552  // TODO: verify visibility for declarations.
1553  return success();
1554 }
1555 
1556 ElementsAttr GlobalOp::getConstantInitValue() {
1557  auto initVal = getInitialValue();
1558  if (getConstant() && initVal.has_value())
1559  return llvm::cast<ElementsAttr>(initVal.value());
1560  return {};
1561 }
1562 
1563 //===----------------------------------------------------------------------===//
1564 // GetGlobalOp
1565 //===----------------------------------------------------------------------===//
1566 
1567 LogicalResult
1568 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1569  // Verify that the result type is same as the type of the referenced
1570  // memref.global op.
1571  auto global =
1572  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1573  if (!global)
1574  return emitOpError("'")
1575  << getName() << "' does not reference a valid global memref";
1576 
1577  Type resultType = getResult().getType();
1578  if (global.getType() != resultType)
1579  return emitOpError("result type ")
1580  << resultType << " does not match type " << global.getType()
1581  << " of the global memref @" << getName();
1582  return success();
1583 }
1584 
1585 //===----------------------------------------------------------------------===//
1586 // LoadOp
1587 //===----------------------------------------------------------------------===//
1588 
1589 LogicalResult LoadOp::verify() {
1590  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1591  return emitOpError("incorrect number of indices for load, expected ")
1592  << getMemRefType().getRank() << " but got " << getIndices().size();
1593  }
1594  return success();
1595 }
1596 
1597 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1598  /// load(memrefcast) -> load
1599  if (succeeded(foldMemRefCast(*this)))
1600  return getResult();
1601  return OpFoldResult();
1602 }
1603 
1604 //===----------------------------------------------------------------------===//
1605 // MemorySpaceCastOp
1606 //===----------------------------------------------------------------------===//
1607 
1608 void MemorySpaceCastOp::getAsmResultNames(
1609  function_ref<void(Value, StringRef)> setNameFn) {
1610  setNameFn(getResult(), "memspacecast");
1611 }
1612 
1613 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1614  if (inputs.size() != 1 || outputs.size() != 1)
1615  return false;
1616  Type a = inputs.front(), b = outputs.front();
1617  auto aT = llvm::dyn_cast<MemRefType>(a);
1618  auto bT = llvm::dyn_cast<MemRefType>(b);
1619 
1620  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1621  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1622 
1623  if (aT && bT) {
1624  if (aT.getElementType() != bT.getElementType())
1625  return false;
1626  if (aT.getLayout() != bT.getLayout())
1627  return false;
1628  if (aT.getShape() != bT.getShape())
1629  return false;
1630  return true;
1631  }
1632  if (uaT && ubT) {
1633  return uaT.getElementType() == ubT.getElementType();
1634  }
1635  return false;
1636 }
1637 
1638 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1639  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1640  // t2)
1641  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1642  getSourceMutable().assign(parentCast.getSource());
1643  return getResult();
1644  }
1645  return Value{};
1646 }
1647 
1648 //===----------------------------------------------------------------------===//
1649 // PrefetchOp
1650 //===----------------------------------------------------------------------===//
1651 
1652 void PrefetchOp::print(OpAsmPrinter &p) {
1653  p << " " << getMemref() << '[';
1654  p.printOperands(getIndices());
1655  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1656  p << ", locality<" << getLocalityHint();
1657  p << ">, " << (getIsDataCache() ? "data" : "instr");
1658  p.printOptionalAttrDict(
1659  (*this)->getAttrs(),
1660  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1661  p << " : " << getMemRefType();
1662 }
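// For illustration (an informal example of the textual form printed above and
// accepted by the parser below):
//
//   memref.prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xf32>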
1663 
1664 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1665  OpAsmParser::UnresolvedOperand memrefInfo;
1666  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1667  IntegerAttr localityHint;
1668  MemRefType type;
1669  StringRef readOrWrite, cacheType;
1670 
1671  auto indexTy = parser.getBuilder().getIndexType();
1672  auto i32Type = parser.getBuilder().getIntegerType(32);
1673  if (parser.parseOperand(memrefInfo) ||
1674  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1675  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1676  parser.parseComma() || parser.parseKeyword("locality") ||
1677  parser.parseLess() ||
1678  parser.parseAttribute(localityHint, i32Type, "localityHint",
1679  result.attributes) ||
1680  parser.parseGreater() || parser.parseComma() ||
1681  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1682  parser.resolveOperand(memrefInfo, type, result.operands) ||
1683  parser.resolveOperands(indexInfo, indexTy, result.operands))
1684  return failure();
1685 
1686  if (readOrWrite != "read" && readOrWrite != "write")
1687  return parser.emitError(parser.getNameLoc(),
1688  "rw specifier has to be 'read' or 'write'");
1689  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1690  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1691 
1692  if (cacheType != "data" && cacheType != "instr")
1693  return parser.emitError(parser.getNameLoc(),
1694  "cache type has to be 'data' or 'instr'");
1695 
1696  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1697  parser.getBuilder().getBoolAttr(cacheType == "data"));
1698 
1699  return success();
1700 }
1701 
1702 LogicalResult PrefetchOp::verify() {
1703  if (getNumOperands() != 1 + getMemRefType().getRank())
1704  return emitOpError("too few indices");
1705 
1706  return success();
1707 }
1708 
1709 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1710  SmallVectorImpl<OpFoldResult> &results) {
1711  // prefetch(memrefcast) -> prefetch
1712  return foldMemRefCast(*this);
1713 }
1714 
1715 //===----------------------------------------------------------------------===//
1716 // RankOp
1717 //===----------------------------------------------------------------------===//
1718 
1719 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1720  // Constant fold rank when the rank of the operand is known.
1721  auto type = getOperand().getType();
1722  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1723  if (shapedType && shapedType.hasRank())
1724  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1725  return IntegerAttr();
1726 }
1727 
1728 //===----------------------------------------------------------------------===//
1729 // ReinterpretCastOp
1730 //===----------------------------------------------------------------------===//
1731 
1732 void ReinterpretCastOp::getAsmResultNames(
1733  function_ref<void(Value, StringRef)> setNameFn) {
1734  setNameFn(getResult(), "reinterpret_cast");
1735 }
1736 
1737 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1738 /// `staticSizes` and `staticStrides` are automatically filled with
1739 /// source-memref-rank sentinel values that encode dynamic entries.
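/// For illustration (informal): building with offset = %off (a Value), sizes =
/// [4, %n] and strides = [%s, 1] records static_offsets = [kDynamic],
/// static_sizes = [4, kDynamic], static_strides = [kDynamic, 1], with %off, %n
/// and %s passed as the dynamic operands.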
1740 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1741  MemRefType resultType, Value source,
1742  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1743  ArrayRef<OpFoldResult> strides,
1744  ArrayRef<NamedAttribute> attrs) {
1745  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1746  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1747  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1748  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1749  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1750  result.addAttributes(attrs);
1751  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1752  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1753  b.getDenseI64ArrayAttr(staticSizes),
1754  b.getDenseI64ArrayAttr(staticStrides));
1755 }
1756 
1757 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1758  Value source, OpFoldResult offset,
1759  ArrayRef<OpFoldResult> sizes,
1760  ArrayRef<OpFoldResult> strides,
1761  ArrayRef<NamedAttribute> attrs) {
1762  auto sourceType = cast<BaseMemRefType>(source.getType());
1763  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1764  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1765  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1766  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1767  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1768  auto stridedLayout = StridedLayoutAttr::get(
1769  b.getContext(), staticOffsets.front(), staticStrides);
1770  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1771  stridedLayout, sourceType.getMemorySpace());
1772  build(b, result, resultType, source, offset, sizes, strides, attrs);
1773 }
1774 
1775 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1776  MemRefType resultType, Value source,
1777  int64_t offset, ArrayRef<int64_t> sizes,
1778  ArrayRef<int64_t> strides,
1779  ArrayRef<NamedAttribute> attrs) {
1780  SmallVector<OpFoldResult> sizeValues =
1781  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1782  return b.getI64IntegerAttr(v);
1783  }));
1784  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1785  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1786  return b.getI64IntegerAttr(v);
1787  }));
1788  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1789  strideValues, attrs);
1790 }
1791 
1792 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1793  MemRefType resultType, Value source, Value offset,
1794  ValueRange sizes, ValueRange strides,
1795  ArrayRef<NamedAttribute> attrs) {
1796  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1797  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1798  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1799  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1800  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1801 }
1802 
1803 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1804 // completed automatically, like we have for subview and extract_slice.
1805 LogicalResult ReinterpretCastOp::verify() {
1806  // The source and result memrefs should be in the same memory space.
1807  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1808  auto resultType = llvm::cast<MemRefType>(getType());
1809  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1810  return emitError("different memory spaces specified for source type ")
1811  << srcType << " and result memref type " << resultType;
1812  if (srcType.getElementType() != resultType.getElementType())
1813  return emitError("different element types specified for source type ")
1814  << srcType << " and result memref type " << resultType;
1815 
1816  // Match sizes in result memref type and in static_sizes attribute.
1817  for (auto [idx, resultSize, expectedSize] :
1818  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1819  if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
1820  return emitError("expected result type with size = ")
1821  << (ShapedType::isDynamic(expectedSize)
1822  ? std::string("dynamic")
1823  : std::to_string(expectedSize))
1824  << " instead of " << resultSize << " in dim = " << idx;
1825  }
1826 
1827  // Match offset and strides in static_offset and static_strides attributes. If
1828  // result memref type has no affine map specified, this will assume an
1829  // identity layout.
1830  int64_t resultOffset;
1831  SmallVector<int64_t, 4> resultStrides;
1832  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1833  return emitError("expected result type to have strided layout but found ")
1834  << resultType;
1835 
1836  // Match offset in result memref type and in static_offsets attribute.
1837  int64_t expectedOffset = getStaticOffsets().front();
1838  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
1839  return emitError("expected result type with offset = ")
1840  << (ShapedType::isDynamic(expectedOffset)
1841  ? std::string("dynamic")
1842  : std::to_string(expectedOffset))
1843  << " instead of " << resultOffset;
1844 
1845  // Match strides in result memref type and in static_strides attribute.
1846  for (auto [idx, resultStride, expectedStride] :
1847  llvm::enumerate(resultStrides, getStaticStrides())) {
1848  if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
1849  return emitError("expected result type with stride = ")
1850  << (ShapedType::isDynamic(expectedStride)
1851  ? std::string("dynamic")
1852  : std::to_string(expectedStride))
1853  << " instead of " << resultStride << " in dim = " << idx;
1854  }
1855 
1856  return success();
1857 }
1858 
1859 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1860  Value src = getSource();
1861  auto getPrevSrc = [&]() -> Value {
1862  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1863  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1864  return prev.getSource();
1865 
1866  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1867  if (auto prev = src.getDefiningOp<CastOp>())
1868  return prev.getSource();
1869 
1870  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1871  // are 0.
1872  if (auto prev = src.getDefiningOp<SubViewOp>())
1873  if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
1874  return prev.getSource();
1875 
1876  return nullptr;
1877  };
1878 
1879  if (auto prevSrc = getPrevSrc()) {
1880  getSourceMutable().assign(prevSrc);
1881  return getResult();
1882  }
1883 
1884  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1885  if (ShapedType::isStaticShape(getType().getShape()) &&
1886  src.getType() == getType() && getStaticOffsets().front() == 0) {
1887  return src;
1888  }
1889 
1890  return nullptr;
1891 }
1892 
1893 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1894  SmallVector<OpFoldResult> values = getMixedSizes();
1895  constifyIndexValues(values, getType().getShape());
1896  return values;
1897 }
1898 
1899 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1900  SmallVector<OpFoldResult> values = getMixedStrides();
1901  SmallVector<int64_t> staticValues;
1902  int64_t unused;
1903  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
1904  (void)status;
1905  assert(succeeded(status) && "could not get strides from type");
1906  constifyIndexValues(values, staticValues);
1907  return values;
1908 }
1909 
1910 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1911  SmallVector<OpFoldResult> values = getMixedOffsets();
1912  assert(values.size() == 1 &&
1913  "reinterpret_cast must have one and only one offset");
1914  SmallVector<int64_t> staticValues, unused;
1915  int64_t offset;
1916  LogicalResult status = getType().getStridesAndOffset(unused, offset);
1917  (void)status;
1918  assert(succeeded(status) && "could not get offset from type");
1919  staticValues.push_back(offset);
1920  constifyIndexValues(values, staticValues);
1921  return values[0];
1922 }
1923 
1924 namespace {
1925 /// Replace the sequence:
1926 /// ```
1927 /// base, offset, sizes, strides = extract_strided_metadata src
1928 /// dst = reinterpret_cast base to offset, sizes, strides
1929 /// ```
1930 /// With
1931 ///
1932 /// ```
1933 /// dst = memref.cast src
1934 /// ```
1935 ///
1936 /// Note: The cast operation is only inserted when the type of dst and src
1937 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
1938 ///
1939 /// This pattern also matches when the offset, sizes, and strides don't come
1940 /// directly from the `extract_strided_metadata`'s results but it can be
1941 /// statically proven that they would hold the same values.
1942 ///
1943 /// For instance, the following sequence would be replaced:
1944 /// ```
1945 /// base, offset, sizes, strides =
1946 /// extract_strided_metadata memref : memref<3x4xty>
1947 /// dst = reinterpret_cast base to 0, [3, 4], strides
1948 /// ```
1949 /// Because we know (thanks to the type of the input memref) that variable
1950 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
1951 ///
1952 /// Similarly, the following sequence would be replaced:
1953 /// ```
1954 /// c0 = arith.constant 0
1955 /// c4 = arith.constant 4
1956 /// base, offset, sizes, strides =
1957 /// extract_strided_metadata memref : memref<3x4xty>
1958 /// dst = reinterpret_cast base to c0, [3, c4], strides
1959 /// ```
1960 /// Because we know that `offset` and `c0` will hold 0
1961 /// and `c4` will hold 4.
1962 ///
1963 /// If the pattern above does not match, the input of the
1964 /// extract_strided_metadata is always folded into the input of the
1965 /// reinterpret_cast operator. This allows for dead code elimination to get rid
1966 /// of the extract_strided_metadata in some cases.
1967 struct ReinterpretCastOpExtractStridedMetadataFolder
1968  : public OpRewritePattern<ReinterpretCastOp> {
1969 public:
1970  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
1971 
1972  LogicalResult matchAndRewrite(ReinterpretCastOp op,
1973  PatternRewriter &rewriter) const override {
1974  auto extractStridedMetadata =
1975  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
1976  if (!extractStridedMetadata)
1977  return failure();
1978 
1979  // Check if the reinterpret cast reconstructs a memref with the exact same
1980  // properties as the extract strided metadata.
1981  auto isReinterpretCastNoop = [&]() -> bool {
1982  // First, check that the strides are the same.
1983  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
1984  op.getConstifiedMixedStrides()))
1985  return false;
1986 
1987  // Second, check the sizes.
1988  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
1989  op.getConstifiedMixedSizes()))
1990  return false;
1991 
1992  // Finally, check the offset.
1993  assert(op.getMixedOffsets().size() == 1 &&
1994  "reinterpret_cast with more than one offset should have been "
1995  "rejected by the verifier");
1996  return extractStridedMetadata.getConstifiedMixedOffset() ==
1997  op.getConstifiedMixedOffset();
1998  };
1999 
2000  if (!isReinterpretCastNoop()) {
2001  // If the extract_strided_metadata / reinterpret_cast pair can't be
2002  // completely folded, we can still fold the input of the
2003  // extract_strided_metadata into the input of the reinterpret_cast.
2004  // In some cases (e.g., static dimensions) the extract_strided_metadata
2005  // is then eliminated by dead code elimination.
2006  //
2007  // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2008  //
2009  // We can always fold the input of a extract_strided_metadata operator
2010  // to the input of a reinterpret_cast operator, because they point to
2011  // the same memory. Note that the reinterpret_cast does not use the
2012  // layout of its input memref, only its base memory pointer which is
2013  // the same as the base pointer returned by the extract_strided_metadata
2014  // operator and the base pointer of the extract_strided_metadata memref
2015  // input.
2016  rewriter.modifyOpInPlace(op, [&]() {
2017  op.getSourceMutable().assign(extractStridedMetadata.getSource());
2018  });
2019  return success();
2020  }
2021 
2022  // At this point, we know that the back and forth between extract strided
2023  // metadata and reinterpret cast is a noop. However, the final type of the
2024  // reinterpret cast may not be exactly the same as the original memref.
2025  // E.g., it could be changing a dimension from static to dynamic. Check that
2026  // here and add a cast if necessary.
2027  Type srcTy = extractStridedMetadata.getSource().getType();
2028  if (srcTy == op.getResult().getType())
2029  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2030  else
2031  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2032  extractStridedMetadata.getSource());
2033 
2034  return success();
2035  }
2036 };
2037 } // namespace
2038 
2039 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2040  MLIRContext *context) {
2041  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2042 }
2043 
2044 //===----------------------------------------------------------------------===//
2045 // Reassociative reshape ops
2046 //===----------------------------------------------------------------------===//
2047 
2048 void CollapseShapeOp::getAsmResultNames(
2049  function_ref<void(Value, StringRef)> setNameFn) {
2050  setNameFn(getResult(), "collapse_shape");
2051 }
2052 
2053 void ExpandShapeOp::getAsmResultNames(
2054  function_ref<void(Value, StringRef)> setNameFn) {
2055  setNameFn(getResult(), "expand_shape");
2056 }
2057 
2058 LogicalResult ExpandShapeOp::reifyResultShapes(
2059  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2060  reifiedResultShapes = {
2061  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2062  return success();
2063 }
2064 
2065 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2066 /// result and operand. Layout maps are verified separately.
2067 ///
2068 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2069 /// allowed in a reassociation group.
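/// For illustration (an informal example): collapsedShape = [6, ?],
/// expandedShape = [2, 3, ?] and reassociation = [[0, 1], [2]] verify: the
/// first group is all-static with 2 * 3 == 6, and the second group is dynamic
/// exactly when the corresponding collapsed dim is.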
2070 static LogicalResult
2071 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2072  ArrayRef<int64_t> expandedShape,
2073  ArrayRef<ReassociationIndices> reassociation,
2074  bool allowMultipleDynamicDimsPerGroup) {
2075  // There must be one reassociation group per collapsed dimension.
2076  if (collapsedShape.size() != reassociation.size())
2077  return op->emitOpError("invalid number of reassociation groups: found ")
2078  << reassociation.size() << ", expected " << collapsedShape.size();
2079 
2080  // The next expected expanded dimension index (while iterating over
2081  // reassociation indices).
2082  int64_t nextDim = 0;
2083  for (const auto &it : llvm::enumerate(reassociation)) {
2084  ReassociationIndices group = it.value();
2085  int64_t collapsedDim = it.index();
2086 
2087  bool foundDynamic = false;
2088  for (int64_t expandedDim : group) {
2089  if (expandedDim != nextDim++)
2090  return op->emitOpError("reassociation indices must be contiguous");
2091 
2092  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2093  return op->emitOpError("reassociation index ")
2094  << expandedDim << " is out of bounds";
2095 
2096  // Check if there are multiple dynamic dims in a reassociation group.
2097  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2098  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2099  return op->emitOpError(
2100  "at most one dimension in a reassociation group may be dynamic");
2101  foundDynamic = true;
2102  }
2103  }
2104 
2105  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2106  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2107  return op->emitOpError("collapsed dim (")
2108  << collapsedDim
2109  << ") must be dynamic if and only if reassociation group is "
2110  "dynamic";
2111 
2112  // If all dims in the reassociation group are static, the size of the
2113  // collapsed dim can be verified.
2114  if (!foundDynamic) {
2115  int64_t groupSize = 1;
2116  for (int64_t expandedDim : group)
2117  groupSize *= expandedShape[expandedDim];
2118  if (groupSize != collapsedShape[collapsedDim])
2119  return op->emitOpError("collapsed dim size (")
2120  << collapsedShape[collapsedDim]
2121  << ") must equal reassociation group size (" << groupSize << ")";
2122  }
2123  }
2124 
2125  if (collapsedShape.empty()) {
2126  // Rank 0: All expanded dimensions must be 1.
2127  for (int64_t d : expandedShape)
2128  if (d != 1)
2129  return op->emitOpError(
2130  "rank 0 memrefs can only be extended/collapsed with/from ones");
2131  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2132  // Rank >= 1: Number of dimensions among all reassociation groups must match
2133  // the result memref rank.
2134  return op->emitOpError("expanded rank (")
2135  << expandedShape.size()
2136  << ") inconsistent with number of reassociation indices (" << nextDim
2137  << ")";
2138  }
2139 
2140  return success();
2141 }
2142 
2143 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2144  return getSymbolLessAffineMaps(getReassociationExprs());
2145 }
2146 
2147 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2148  return convertReassociationIndicesToExprs(getContext(),
2149  getReassociationIndices());
2150 }
2151 
2152 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2153  return getSymbolLessAffineMaps(getReassociationExprs());
2154 }
2155 
2156 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2157  return convertReassociationIndicesToExprs(getContext(),
2158  getReassociationIndices());
2159 }
2160 
2161 /// Compute the layout map after expanding a given source MemRef type with the
2162 /// specified reassociation indices.
2163 static FailureOr<StridedLayoutAttr>
2164 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2165  ArrayRef<ReassociationIndices> reassociation) {
2166  int64_t srcOffset;
2167  SmallVector<int64_t> srcStrides;
2168  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2169  return failure();
2170  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2171 
2172  // 1-1 mapping between srcStrides and reassociation packs.
2173  // Each srcStride starts with the given value and gets expanded according to
2174  // the proper entries in resultShape.
2175  // Example:
2176  // srcStrides = [10000, 1 , 100 ],
2177  // reassociations = [ [0], [1], [2, 3, 4]],
2178  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2179  // -> For the purpose of stride calculation, the useful sizes are:
2180  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2181  // resultStrides = [10000, 1, 600, 200, 100]
2182  // Note that a stride does not get expanded along the first entry of each
2183  // shape pack.
2184  SmallVector<int64_t> reverseResultStrides;
2185  reverseResultStrides.reserve(resultShape.size());
2186  unsigned shapeIndex = resultShape.size() - 1;
2187  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2188  ReassociationIndices reassoc = std::get<0>(it);
2189  int64_t currentStrideToExpand = std::get<1>(it);
2190  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2191  reverseResultStrides.push_back(currentStrideToExpand);
2192  currentStrideToExpand =
2193  (SaturatedInteger::wrap(currentStrideToExpand) *
2194  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2195  .asInteger();
2196  }
2197  }
2198  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2199  resultStrides.resize(resultShape.size(), 1);
2200  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2201 }
2202 
2203 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2204  MemRefType srcType, ArrayRef<int64_t> resultShape,
2205  ArrayRef<ReassociationIndices> reassociation) {
2206  if (srcType.getLayout().isIdentity()) {
2207  // If the source is contiguous (i.e., no layout map specified), so is the
2208  // result.
2209  MemRefLayoutAttrInterface layout;
2210  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2211  srcType.getMemorySpace());
2212  }
2213 
2214  // Source may not be contiguous. Compute the layout map.
2215  FailureOr<StridedLayoutAttr> computedLayout =
2216  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2217  if (failed(computedLayout))
2218  return failure();
2219  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2220  srcType.getMemorySpace());
2221 }
2222 
2223 FailureOr<SmallVector<OpFoldResult>>
2224 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2225  MemRefType expandedType,
2226  ArrayRef<ReassociationIndices> reassociation,
2227  ArrayRef<OpFoldResult> inputShape) {
2228  std::optional<SmallVector<OpFoldResult>> outputShape =
2229  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2230  inputShape);
2231  if (!outputShape)
2232  return failure();
2233  return *outputShape;
2234 }
2235 
2236 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2237  Type resultType, Value src,
2238  ArrayRef<ReassociationIndices> reassociation,
2239  ArrayRef<OpFoldResult> outputShape) {
2240  auto [staticOutputShape, dynamicOutputShape] =
2241  decomposeMixedValues(outputShape);
2242  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2243  getReassociationIndicesAttribute(builder, reassociation),
2244  dynamicOutputShape, staticOutputShape);
2245 }
2246 
2247 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2248  Type resultType, Value src,
2249  ArrayRef<ReassociationIndices> reassociation) {
2250  SmallVector<OpFoldResult> inputShape =
2251  getMixedSizes(builder, result.location, src);
2252  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2253  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2254  builder, result.location, memrefResultTy, reassociation, inputShape);
2255  // Failure of this assertion usually indicates presence of multiple
2256  // dynamic dimensions in the same reassociation group.
2257  assert(succeeded(outputShape) && "unable to infer output shape");
2258  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2259 }
2260 
2261 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2262  ArrayRef<int64_t> resultShape, Value src,
2263  ArrayRef<ReassociationIndices> reassociation) {
2264  // Only ranked memref source values are supported.
2265  auto srcType = llvm::cast<MemRefType>(src.getType());
2266  FailureOr<MemRefType> resultType =
2267  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2268  // Failure of this assertion usually indicates a problem with the source
2269  // type, e.g., could not get strides/offset.
2270  assert(succeeded(resultType) && "could not compute layout");
2271  build(builder, result, *resultType, src, reassociation);
2272 }
2273 
2274 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2275  ArrayRef<int64_t> resultShape, Value src,
2276  ArrayRef<ReassociationIndices> reassociation,
2277  ArrayRef<OpFoldResult> outputShape) {
2278  // Only ranked memref source values are supported.
2279  auto srcType = llvm::cast<MemRefType>(src.getType());
2280  FailureOr<MemRefType> resultType =
2281  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2282  // Failure of this assertion usually indicates a problem with the source
2283  // type, e.g., could not get strides/offset.
2284  assert(succeeded(resultType) && "could not compute layout");
2285  build(builder, result, *resultType, src, reassociation, outputShape);
2286 }
2287 
2288 LogicalResult ExpandShapeOp::verify() {
2289  MemRefType srcType = getSrcType();
2290  MemRefType resultType = getResultType();
2291 
2292  if (srcType.getRank() > resultType.getRank()) {
2293  auto r0 = srcType.getRank();
2294  auto r1 = resultType.getRank();
2295  return emitOpError("has source rank ")
2296  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2297  << r0 << " > " << r1 << ").";
2298  }
2299 
2300  // Verify result shape.
2301  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2302  resultType.getShape(),
2303  getReassociationIndices(),
2304  /*allowMultipleDynamicDimsPerGroup=*/true)))
2305  return failure();
2306 
2307  // Compute expected result type (including layout map).
2308  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2309  srcType, resultType.getShape(), getReassociationIndices());
2310  if (failed(expectedResultType))
2311  return emitOpError("invalid source layout map");
2312 
2313  // Check actual result type.
2314  if (*expectedResultType != resultType)
2315  return emitOpError("expected expanded type to be ")
2316  << *expectedResultType << " but found " << resultType;
2317 
2318  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2319  return emitOpError("expected number of static shape bounds to be equal to "
2320  "the output rank (")
2321  << resultType.getRank() << ") but found "
2322  << getStaticOutputShape().size() << " inputs instead";
2323 
2324  if ((int64_t)getOutputShape().size() !=
2325  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2326  return emitOpError("mismatch in dynamic dims in output_shape and "
2327  "static_output_shape: static_output_shape has ")
2328  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2329  << " dynamic dims while output_shape has " << getOutputShape().size()
2330  << " values";
2331 
2332  // Verify if provided output shapes are in agreement with output type.
2333  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2334  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2335  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2336  if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
2337  return emitOpError("invalid output shape provided at pos ") << pos;
2338  }
2339  }
2340 
2341  return success();
2342 }
2343 
2344 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2345  MLIRContext *context) {
2346  results.add<
2347  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2348  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2349 }
2350 
2351 /// Compute the layout map after collapsing a given source MemRef type with the
2352 /// specified reassociation indices.
2353 ///
2354 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2355 /// not possible to check this by inspecting a MemRefType in the general case.
2356 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2357 /// be valid (and thus accepted by this function) unless `strict = true`.
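/// For illustration (an informal example): srcType = memref<2x3x4xf32,
/// strided<[12, 4, 1]>> with reassociation = [[0, 1], [2]] yields result
/// strides [4, 1] (the stride of the last dim in each group); the contiguity
/// check then verifies 4 * 3 == 12 for the first group.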
2358 static FailureOr<StridedLayoutAttr>
2359 computeCollapsedLayoutMap(MemRefType srcType,
2360  ArrayRef<ReassociationIndices> reassociation,
2361  bool strict = false) {
2362  int64_t srcOffset;
2363  SmallVector<int64_t> srcStrides;
2364  auto srcShape = srcType.getShape();
2365  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2366  return failure();
2367 
2368  // The result stride of a reassociation group is the stride of the last entry
2369  // of the reassociation. (TODO: Should be the minimum stride in the
2370  // reassociation because strides are not necessarily sorted. E.g., when using
2371  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2372  // strides are meaningless and could have any arbitrary value.
2373  SmallVector<int64_t> resultStrides;
2374  resultStrides.reserve(reassociation.size());
2375  for (const ReassociationIndices &reassoc : reassociation) {
2376  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2377  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2378  ref = ref.drop_back();
2379  if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2380  resultStrides.push_back(srcStrides[ref.back()]);
2381  } else {
2382  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2383  // the corresponding stride may have to be skipped. (See above comment.)
2384  // Therefore, the result stride cannot be statically determined and must
2385  // be dynamic.
2386  resultStrides.push_back(ShapedType::kDynamic);
2387  }
2388  }
2389 
2390  // Validate that each reassociation group is contiguous.
2391  unsigned resultStrideIndex = resultStrides.size() - 1;
2392  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2393  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2394  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2395  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2396  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2397 
2398  // Both source and result stride must have the same static value. In that
2399  // case, we can be sure, that the dimensions are collapsible (because they
2400  // are contiguous).
2401  // If `strict = false` (default during op verification), we accept cases
2402  // where one or both strides are dynamic. This is best effort: We reject
2403  // ops where obviously non-contiguous dims are collapsed, but accept ops
2404  // where we cannot be sure statically. Such ops may fail at runtime. See
2405  // the op documentation for details.
2406  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2407  if (strict && (stride.saturated || srcStride.saturated))
2408  return failure();
2409 
2410  // Dimensions of size 1 should be skipped, because their strides are
2411  // meaningless and could have any arbitrary value.
2412  if (srcShape[idx - 1] == 1)
2413  continue;
2414 
2415  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2416  return failure();
2417  }
2418  }
2419  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2420 }
2421 
2422 bool CollapseShapeOp::isGuaranteedCollapsible(
2423  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2424  // MemRefs with identity layout are always collapsible.
2425  if (srcType.getLayout().isIdentity())
2426  return true;
2427 
2428  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2429  /*strict=*/true));
2430 }
2431 
2432 MemRefType CollapseShapeOp::computeCollapsedType(
2433  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2434  SmallVector<int64_t> resultShape;
2435  resultShape.reserve(reassociation.size());
2436  for (const ReassociationIndices &group : reassociation) {
2437  auto groupSize = SaturatedInteger::wrap(1);
2438  for (int64_t srcDim : group)
2439  groupSize =
2440  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2441  resultShape.push_back(groupSize.asInteger());
2442  }
2443 
2444  if (srcType.getLayout().isIdentity()) {
2445  // If the source is contiguous (i.e., no layout map specified), so is the
2446  // result.
2447  MemRefLayoutAttrInterface layout;
2448  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2449  srcType.getMemorySpace());
2450  }
2451 
2452  // Source may not be fully contiguous. Compute the layout map.
2453  // Note: Dimensions that are collapsed into a single dim are assumed to be
2454  // contiguous.
2455  FailureOr<StridedLayoutAttr> computedLayout =
2456  computeCollapsedLayoutMap(srcType, reassociation);
2457  assert(succeeded(computedLayout) &&
2458  "invalid source layout map or collapsing non-contiguous dims");
2459  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2460  srcType.getMemorySpace());
2461 }
2462 
2463 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2464  ArrayRef<ReassociationIndices> reassociation,
2465  ArrayRef<NamedAttribute> attrs) {
2466  auto srcType = llvm::cast<MemRefType>(src.getType());
2467  MemRefType resultType =
2468  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2469  result.addAttribute(::mlir::getReassociationAttrName(),
2470  getReassociationIndicesAttribute(b, reassociation));
2471  build(b, result, resultType, src, attrs);
2472 }
2473 
2474 LogicalResult CollapseShapeOp::verify() {
2475  MemRefType srcType = getSrcType();
2476  MemRefType resultType = getResultType();
2477 
2478  if (srcType.getRank() < resultType.getRank()) {
2479  auto r0 = srcType.getRank();
2480  auto r1 = resultType.getRank();
2481  return emitOpError("has source rank ")
2482  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2483  << r0 << " < " << r1 << ").";
2484  }
2485 
2486  // Verify result shape.
2487  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2488  srcType.getShape(), getReassociationIndices(),
2489  /*allowMultipleDynamicDimsPerGroup=*/true)))
2490  return failure();
2491 
2492  // Compute expected result type (including layout map).
2493  MemRefType expectedResultType;
2494  if (srcType.getLayout().isIdentity()) {
2495  // If the source is contiguous (i.e., no layout map specified), so is the
2496  // result.
2497  MemRefLayoutAttrInterface layout;
2498  expectedResultType =
2499  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2500  srcType.getMemorySpace());
2501  } else {
2502  // Source may not be fully contiguous. Compute the layout map.
2503  // Note: Dimensions that are collapsed into a single dim are assumed to be
2504  // contiguous.
2505  FailureOr<StridedLayoutAttr> computedLayout =
2506  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2507  if (failed(computedLayout))
2508  return emitOpError(
2509  "invalid source layout map or collapsing non-contiguous dims");
2510  expectedResultType =
2511  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2512  *computedLayout, srcType.getMemorySpace());
2513  }
2514 
2515  if (expectedResultType != resultType)
2516  return emitOpError("expected collapsed type to be ")
2517  << expectedResultType << " but found " << resultType;
2518 
2519  return success();
2520 }
2521 
2522 struct CollapseShapeOpMemRefCastFolder
2523  : public OpRewritePattern<CollapseShapeOp> {
2524 public:
2525  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2526 
2527  LogicalResult matchAndRewrite(CollapseShapeOp op,
2528  PatternRewriter &rewriter) const override {
2529  auto cast = op.getOperand().getDefiningOp<CastOp>();
2530  if (!cast)
2531  return failure();
2532 
2533  if (!CastOp::canFoldIntoConsumerOp(cast))
2534  return failure();
2535 
2536  Type newResultType = CollapseShapeOp::computeCollapsedType(
2537  llvm::cast<MemRefType>(cast.getOperand().getType()),
2538  op.getReassociationIndices());
2539 
2540  if (newResultType == op.getResultType()) {
2541  rewriter.modifyOpInPlace(
2542  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2543  } else {
2544  Value newOp =
2545  CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
2546  op.getReassociationIndices());
2547  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2548  }
2549  return success();
2550  }
2551 };
2552 
2553 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2554  MLIRContext *context) {
2555  results.add<
2556  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2557  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2558  memref::DimOp, MemRefType>,
2559  CollapseShapeOpMemRefCastFolder>(context);
2560 }
2561 
2562 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2563  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2564  adaptor.getOperands());
2565 }
2566 
2567 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2568  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2569  adaptor.getOperands());
2570 }
2571 
2572 //===----------------------------------------------------------------------===//
2573 // ReshapeOp
2574 //===----------------------------------------------------------------------===//
2575 
2576 void ReshapeOp::getAsmResultNames(
2577  function_ref<void(Value, StringRef)> setNameFn) {
2578  setNameFn(getResult(), "reshape");
2579 }
2580 
2581 LogicalResult ReshapeOp::verify() {
2582  Type operandType = getSource().getType();
2583  Type resultType = getResult().getType();
2584 
2585  Type operandElementType =
2586  llvm::cast<ShapedType>(operandType).getElementType();
2587  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2588  if (operandElementType != resultElementType)
2589  return emitOpError("element types of source and destination memref "
2590  "types should be the same");
2591 
2592  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2593  if (!operandMemRefType.getLayout().isIdentity())
2594  return emitOpError("source memref type should have identity affine map");
2595 
2596  int64_t shapeSize =
2597  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2598  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2599  if (resultMemRefType) {
2600  if (!resultMemRefType.getLayout().isIdentity())
2601  return emitOpError("result memref type should have identity affine map");
2602  if (shapeSize == ShapedType::kDynamic)
2603  return emitOpError("cannot use shape operand with dynamic length to "
2604  "reshape to statically-ranked memref type");
2605  if (shapeSize != resultMemRefType.getRank())
2606  return emitOpError(
2607  "length of shape operand differs from the result's memref rank");
2608  }
2609  return success();
2610 }
2611 
2612 //===----------------------------------------------------------------------===//
2613 // StoreOp
2614 //===----------------------------------------------------------------------===//
2615 
2616 LogicalResult StoreOp::verify() {
2617  if (getNumOperands() != 2 + getMemRefType().getRank())
2618  return emitOpError("store index operand count not equal to memref rank");
2619 
2620  return success();
2621 }
2622 
2623 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2624  SmallVectorImpl<OpFoldResult> &results) {
2625  /// store(memrefcast) -> store
2626  return foldMemRefCast(*this, getValueToStore());
2627 }
2628 
2629 //===----------------------------------------------------------------------===//
2630 // SubViewOp
2631 //===----------------------------------------------------------------------===//
2632 
2633 void SubViewOp::getAsmResultNames(
2634  function_ref<void(Value, StringRef)> setNameFn) {
2635  setNameFn(getResult(), "subview");
2636 }
2637 
2638 /// A subview result type can be fully inferred from the source type and the
2639 /// static representation of offsets, sizes and strides. Special sentinels
2640 /// encode the dynamic case.
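/// For illustration (an informal example): a subview of memref<8x16xf32>
/// (strides [16, 1], offset 0) with offsets [2, 4], sizes [4, 4] and strides
/// [1, 2] infers offset 0 + 2*16 + 4*1 = 36 and strides [16*1, 1*2], i.e.
/// memref<4x4xf32, strided<[16, 2], offset: 36>>.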
2641 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2642  ArrayRef<int64_t> staticOffsets,
2643  ArrayRef<int64_t> staticSizes,
2644  ArrayRef<int64_t> staticStrides) {
2645  unsigned rank = sourceMemRefType.getRank();
2646  (void)rank;
2647  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2648  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2649  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2650 
2651  // Extract source offset and strides.
2652  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2653 
2654  // Compute target offset whose value is:
2655  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2656  int64_t targetOffset = sourceOffset;
2657  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2658  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2659  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2660  SaturatedInteger::wrap(staticOffset) *
2661  SaturatedInteger::wrap(sourceStride))
2662  .asInteger();
2663  }
2664 
2665  // Compute target stride whose value is:
2666  // `sourceStrides_i * staticStrides_i`.
2667  SmallVector<int64_t, 4> targetStrides;
2668  targetStrides.reserve(staticOffsets.size());
2669  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2670  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2671  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2672  SaturatedInteger::wrap(staticStride))
2673  .asInteger());
2674  }
2675 
2676  // The type is now known.
2677  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2678  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2679  targetOffset, targetStrides),
2680  sourceMemRefType.getMemorySpace());
2681 }
2682 
2683 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2684  ArrayRef<OpFoldResult> offsets,
2685  ArrayRef<OpFoldResult> sizes,
2686  ArrayRef<OpFoldResult> strides) {
2687  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2688  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2689  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2690  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2691  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2692  if (!hasValidSizesOffsets(staticOffsets))
2693  return {};
2694  if (!hasValidSizesOffsets(staticSizes))
2695  return {};
2696  if (!hasValidStrides(staticStrides))
2697  return {};
2698  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2699  staticSizes, staticStrides);
2700 }
2701 
2702 MemRefType SubViewOp::inferRankReducedResultType(
2703  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2704  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2705  ArrayRef<int64_t> strides) {
2706  MemRefType inferredType =
2707  inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2708  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2709  "expected result shape rank to be no larger than the inferred rank");
2710  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2711  return inferredType;
2712 
2713  // Compute which dimensions are dropped.
2714  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2715  computeRankReductionMask(inferredType.getShape(), resultShape);
2716  assert(dimsToProject.has_value() && "invalid rank reduction");
2717 
2718  // Compute the layout and result type.
2719  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2720  SmallVector<int64_t> rankReducedStrides;
2721  rankReducedStrides.reserve(resultShape.size());
2722  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2723  if (!dimsToProject->contains(idx))
2724  rankReducedStrides.push_back(value);
2725  }
2726  return MemRefType::get(resultShape, inferredType.getElementType(),
2727  StridedLayoutAttr::get(inferredLayout.getContext(),
2728  inferredLayout.getOffset(),
2729  rankReducedStrides),
2730  inferredType.getMemorySpace());
2731 }
2732 
2733 MemRefType SubViewOp::inferRankReducedResultType(
2734  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2735  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2736  ArrayRef<OpFoldResult> strides) {
2737  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2738  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2739  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2740  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2741  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2742  return SubViewOp::inferRankReducedResultType(
2743  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2744  staticStrides);
2745 }
2746 
2747 // Build a SubViewOp with mixed static and dynamic entries and custom result
2748 // type. If the type passed is nullptr, it is inferred.
2749 void SubViewOp::build(OpBuilder &b, OperationState &result,
2750  MemRefType resultType, Value source,
2751  ArrayRef<OpFoldResult> offsets,
2752  ArrayRef<OpFoldResult> sizes,
2753  ArrayRef<OpFoldResult> strides,
2754  ArrayRef<NamedAttribute> attrs) {
2755  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2756  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2757  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2758  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2759  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2760  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2761  // Structuring implementation this way avoids duplication between builders.
2762  if (!resultType) {
2763  resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2764  staticSizes, staticStrides);
2765  }
2766  result.addAttributes(attrs);
2767  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2768  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2769  b.getDenseI64ArrayAttr(staticSizes),
2770  b.getDenseI64ArrayAttr(staticStrides));
2771 }
2772 
2773 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2774 // type.
2775 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2776  ArrayRef<OpFoldResult> offsets,
2777  ArrayRef<OpFoldResult> sizes,
2778  ArrayRef<OpFoldResult> strides,
2779  ArrayRef<NamedAttribute> attrs) {
2780  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2781 }
2782 
2783 // Build a SubViewOp with static entries and inferred result type.
2784 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2785  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2786  ArrayRef<int64_t> strides,
2787  ArrayRef<NamedAttribute> attrs) {
2788  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2789  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2790  return b.getI64IntegerAttr(v);
2791  }));
2792  SmallVector<OpFoldResult> sizeValues =
2793  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2794  return b.getI64IntegerAttr(v);
2795  }));
2796  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2797  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2798  return b.getI64IntegerAttr(v);
2799  }));
2800  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2801 }
2802 
2803 // Build a SubViewOp with dynamic entries and custom result type. If the
2804 // type passed is nullptr, it is inferred.
2805 void SubViewOp::build(OpBuilder &b, OperationState &result,
2806  MemRefType resultType, Value source,
2807  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2808  ArrayRef<int64_t> strides,
2809  ArrayRef<NamedAttribute> attrs) {
2810  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2811  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2812  return b.getI64IntegerAttr(v);
2813  }));
2814  SmallVector<OpFoldResult> sizeValues =
2815  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2816  return b.getI64IntegerAttr(v);
2817  }));
2818  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2819  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2820  return b.getI64IntegerAttr(v);
2821  }));
2822  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2823  attrs);
2824 }
2825 
2826 // Build a SubViewOp with dynamic entries and custom result type. If the type
2827 // passed is nullptr, it is inferred.
2828 void SubViewOp::build(OpBuilder &b, OperationState &result,
2829  MemRefType resultType, Value source, ValueRange offsets,
2830  ValueRange sizes, ValueRange strides,
2831  ArrayRef<NamedAttribute> attrs) {
2832  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2833  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2834  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2835  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2836  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2837  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2838  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2839 }
2840 
2841 // Build a SubViewOp with dynamic entries and inferred result type.
2842 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2843  ValueRange offsets, ValueRange sizes, ValueRange strides,
2844  ArrayRef<NamedAttribute> attrs) {
2845  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2846 }
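The builder overloads above differ only in how offsets, sizes, and strides are supplied (as OpFoldResult, int64_t, or ValueRange) and in whether the result type is passed explicitly or inferred. As a rough, hand-written illustration (function name, shapes, and values are made up for this note, not taken from the file), a subview mixing static and dynamic entries looks like:

```mlir
func.func @subview_example(%src: memref<8x16x4xf32>, %sz: index) {
  // Static offsets [0, 8, 0] and unit strides, one dynamic size %sz.
  // With source strides [64, 4, 1], the inferred result layout is
  // strided<[64, 4, 1], offset: 32>.
  %0 = memref.subview %src[0, 8, 0] [8, %sz, 4] [1, 1, 1]
      : memref<8x16x4xf32> to memref<8x?x4xf32, strided<[64, 4, 1], offset: 32>>
  return
}
```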
2847 
2848 /// For ViewLikeOpInterface.
2849 Value SubViewOp::getViewSource() { return getSource(); }
2850 
2851 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2852 /// static value).
2853 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2854  int64_t t1Offset, t2Offset;
2855  SmallVector<int64_t> t1Strides, t2Strides;
2856  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2857  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2858  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2859 }
2860 
2861 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2862 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2863 /// marked as dropped in `droppedDims`.
2864 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2865  const llvm::SmallBitVector &droppedDims) {
2866  assert(size_t(t1.getRank()) == droppedDims.size() &&
2867  "incorrect number of bits");
2868  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2869  "incorrect number of dropped dims");
2870  int64_t t1Offset, t2Offset;
2871  SmallVector<int64_t> t1Strides, t2Strides;
2872  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2873  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2874  if (failed(res1) || failed(res2))
2875  return false;
2876  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2877  if (droppedDims[i])
2878  continue;
2879  if (t1Strides[i] != t2Strides[j])
2880  return false;
2881  ++j;
2882  }
2883  return true;
2884 }
2885 
2886 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2887  SubViewOp op, Type expectedType) {
2888  auto memrefType = llvm::cast<ShapedType>(expectedType);
2889  switch (result) {
2890  case SliceVerificationResult::Success:
2891  return success();
2892  case SliceVerificationResult::RankTooLarge:
2893  return op->emitError("expected result rank to be smaller or equal to ")
2894  << "the source rank, but got " << op.getType();
2895  case SliceVerificationResult::SizeMismatch:
2896  return op->emitError("expected result type to be ")
2897  << expectedType
2898  << " or a rank-reduced version. (mismatch of result sizes), but got "
2899  << op.getType();
2900  case SliceVerificationResult::ElemTypeMismatch:
2901  return op->emitError("expected result element type to be ")
2902  << memrefType.getElementType() << ", but got " << op.getType();
2903  case SliceVerificationResult::MemSpaceMismatch:
2904  return op->emitError(
2905  "expected result and source memory spaces to match, but got ")
2906  << op.getType();
2907  case SliceVerificationResult::LayoutMismatch:
2908  return op->emitError("expected result type to be ")
2909  << expectedType
2910  << " or a rank-reduced version. (mismatch of result layout), but "
2911  "got "
2912  << op.getType();
2913  }
2914  llvm_unreachable("unexpected subview verification result");
2915 }
2916 
2917 /// Verifier for SubViewOp.
2918 LogicalResult SubViewOp::verify() {
2919  MemRefType baseType = getSourceType();
2920  MemRefType subViewType = getType();
2921  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
2922  ArrayRef<int64_t> staticSizes = getStaticSizes();
2923  ArrayRef<int64_t> staticStrides = getStaticStrides();
2924 
2925  // The base memref and the view memref should be in the same memory space.
2926  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2927  return emitError("different memory spaces specified for base memref "
2928  "type ")
2929  << baseType << " and subview memref type " << subViewType;
2930 
2931  // Verify that the base memref type has a strided layout map.
2932  if (!baseType.isStrided())
2933  return emitError("base type ") << baseType << " is not strided";
2934 
2935  // Compute the expected result type, assuming that there are no rank
2936  // reductions.
2937  MemRefType expectedType = SubViewOp::inferResultType(
2938  baseType, staticOffsets, staticSizes, staticStrides);
2939 
2940  // Verify all properties of a shaped type: rank, element type and dimension
2941  // sizes. This takes into account potential rank reductions.
2942  auto shapedTypeVerification = isRankReducedType(
2943  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2944  if (shapedTypeVerification != SliceVerificationResult::Success)
2945  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2946 
2947  // Make sure that the memory space did not change.
2948  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
2949  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
2950  *this, expectedType);
2951 
2952  // Verify the offset of the layout map.
2953  if (!haveCompatibleOffsets(expectedType, subViewType))
2954  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2955  *this, expectedType);
2956 
2957  // All that remains to verify are the strides. First, compute
2958  // the unused dimensions due to rank reductions. We have to look at sizes and
2959  // strides to decide which dimensions were dropped. This function also
2960  // partially verifies strides in case of rank reductions.
2961  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
2962  getMixedSizes());
2963  if (failed(unusedDims))
2964  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2965  *this, expectedType);
2966 
2967  // Strides must match.
2968  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
2969  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2970  *this, expectedType);
2971 
2972  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2973  // to the base memref.
2974  SliceBoundsVerificationResult boundsResult =
2975  verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
2976  staticStrides, /*generateErrorMessage=*/true);
2977  if (!boundsResult.isValid)
2978  return getOperation()->emitError(boundsResult.errorMessage);
2979 
2980  return success();
2981 }
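For reference, a hand-written rank-reducing subview that passes all of the checks above (shapes and names are illustrative only, not from this file):

```mlir
func.func @rank_reduced_subview(%src: memref<8x16x4xf32>)
    -> memref<16x4xf32, strided<[4, 1], offset: 192>> {
  // Inferred non-rank-reduced type:
  //   memref<1x16x4xf32, strided<[64, 4, 1], offset: 192>>
  // Dropping the unit dimension gives the result type below; element type,
  // memory space, offset, and the remaining strides all match.
  %0 = memref.subview %src[3, 0, 0] [1, 16, 4] [1, 1, 1]
      : memref<8x16x4xf32> to memref<16x4xf32, strided<[4, 1], offset: 192>>
  return %0 : memref<16x4xf32, strided<[4, 1], offset: 192>>
}
```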
2982 
2983 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2984  return os << "range " << range.offset << ":" << range.size << ":"
2985  << range.stride;
2986 }
2987 
2988 /// Return the list of Range (i.e. offset, size, stride). Each Range
2989 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2990 /// with `b` at location `loc`.
2991 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2992  OpBuilder &b, Location loc) {
2993  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
2994  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
2995  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
2996  SmallVector<Range, 8> res;
2997  unsigned rank = ranks[0];
2998  res.reserve(rank);
2999  for (unsigned idx = 0; idx < rank; ++idx) {
3000  Value offset =
3001  op.isDynamicOffset(idx)
3002  ? op.getDynamicOffset(idx)
3003  : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
3004  Value size =
3005  op.isDynamicSize(idx)
3006  ? op.getDynamicSize(idx)
3007  : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
3008  Value stride =
3009  op.isDynamicStride(idx)
3010  ? op.getDynamicStride(idx)
3011  : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
3012  res.emplace_back(Range{offset, size, stride});
3013  }
3014  return res;
3015 }
3016 
3017 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3018 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3019 /// the rank of the inferred result type if `currentResultType` is lower rank
3020 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3021 /// together with the result type. In this case, it is important to compute
3022 /// the dropped dimensions using `currentSourceType` whose strides align with
3023 /// `currentResultType`.
3024 static MemRefType getCanonicalSubViewResultType(
3025  MemRefType currentResultType, MemRefType currentSourceType,
3026  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3027  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3028  MemRefType nonRankReducedType = SubViewOp::inferResultType(
3029  sourceType, mixedOffsets, mixedSizes, mixedStrides);
3030  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3031  currentSourceType, currentResultType, mixedSizes);
3032  if (failed(unusedDims))
3033  return nullptr;
3034 
3035  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3036  SmallVector<int64_t> shape, strides;
3037  unsigned numDimsAfterReduction =
3038  nonRankReducedType.getRank() - unusedDims->count();
3039  shape.reserve(numDimsAfterReduction);
3040  strides.reserve(numDimsAfterReduction);
3041  for (const auto &[idx, size, stride] :
3042  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3043  nonRankReducedType.getShape(), layout.getStrides())) {
3044  if (unusedDims->test(idx))
3045  continue;
3046  shape.push_back(size);
3047  strides.push_back(stride);
3048  }
3049 
3050  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3051  StridedLayoutAttr::get(sourceType.getContext(),
3052  layout.getOffset(), strides),
3053  nonRankReducedType.getMemorySpace());
3054 }
3055 
3056 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3057  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3058  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3059  unsigned rank = memrefType.getRank();
3060  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3061  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3062  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3063  MemRefType targetType = SubViewOp::inferRankReducedResultType(
3064  targetShape, memrefType, offsets, sizes, strides);
3065  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3066  sizes, strides);
3067 }
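A sketch of the IR this helper produces, assuming a memref<8x1x4xf32> source and targetShape = {8, 4} (illustrative values; the exact layout printing may differ):

```mlir
func.func @canonical_rank_reducing(%src: memref<8x1x4xf32>)
    -> memref<8x4xf32, strided<[4, 1]>> {
  // Zero offsets, unit strides, full source sizes; the result type is
  // rank-reduced to the target shape {8, 4}.
  %0 = memref.subview %src[0, 0, 0] [8, 1, 4] [1, 1, 1]
      : memref<8x1x4xf32> to memref<8x4xf32, strided<[4, 1]>>
  return %0 : memref<8x4xf32, strided<[4, 1]>>
}
```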
3068 
3069 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3070  Value value,
3071  ArrayRef<int64_t> desiredShape) {
3072  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3073  assert(sourceMemrefType && "not a ranked memref type");
3074  auto sourceShape = sourceMemrefType.getShape();
3075  if (sourceShape.equals(desiredShape))
3076  return value;
3077  auto maybeRankReductionMask =
3078  mlir::computeRankReductionMask(sourceShape, desiredShape);
3079  if (!maybeRankReductionMask)
3080  return failure();
3081  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3082 }
3083 
3084 /// Helper method to check if a `subview` operation is trivially a no-op. This
3085 /// is the case if all offsets are zero, all strides are 1, and the source
3086 /// shape is the same as the size of the subview. In such cases, the subview can
3087 /// be folded into its source.
3088 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3089  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3090  return false;
3091 
3092  auto mixedOffsets = subViewOp.getMixedOffsets();
3093  auto mixedSizes = subViewOp.getMixedSizes();
3094  auto mixedStrides = subViewOp.getMixedStrides();
3095 
3096  // Check offsets are zero.
3097  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3098  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3099  return !intValue || intValue.value() != 0;
3100  }))
3101  return false;
3102 
3103  // Check strides are one.
3104  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3105  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3106  return !intValue || intValue.value() != 1;
3107  }))
3108  return false;
3109 
3110  // Check that all size values are static and match the (static) source shape.
3111  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3112  for (const auto &size : llvm::enumerate(mixedSizes)) {
3113  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3114  if (!intValue || *intValue != sourceShape[size.index()])
3115  return false;
3116  }
3117  // All conditions met. The `SubViewOp` is foldable as a no-op.
3118  return true;
3119 }
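A hand-written example of such a trivial subview (illustrative shapes, not from this file); the TrivialSubViewOpFolder pattern below replaces it with its source directly, or with a memref.cast when only the layout differs:

```mlir
func.func @trivial_subview(%src: memref<16x32xf32>)
    -> memref<16x32xf32, strided<[32, 1]>> {
  // Zero offsets, unit strides, sizes equal to the static source shape.
  // Canonicalization rewrites %0 to:
  //   memref.cast %src : memref<16x32xf32> to memref<16x32xf32, strided<[32, 1]>>
  %0 = memref.subview %src[0, 0] [16, 32] [1, 1]
      : memref<16x32xf32> to memref<16x32xf32, strided<[32, 1]>>
  return %0 : memref<16x32xf32, strided<[32, 1]>>
}
```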
3120 
3121 namespace {
3122 /// Pattern to rewrite a subview op with MemRefCast arguments.
3123 /// This essentially pushes memref.cast past its consuming subview when
3124 /// `canFoldIntoConsumerOp` is true.
3125 ///
3126 /// Example:
3127 /// ```
3128 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3129 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3130 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3131 /// ```
3132 /// is rewritten into:
3133 /// ```
3134 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
3135 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3136 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3137 /// ```
3138 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3139 public:
3140  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3141 
3142  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3143  PatternRewriter &rewriter) const override {
3144  // If any operand is a constant, just return to let
3145  // SubViewOpConstantFolder kick in.
3146  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3147  return matchPattern(operand, matchConstantIndex());
3148  }))
3149  return failure();
3150 
3151  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3152  if (!castOp)
3153  return failure();
3154 
3155  if (!CastOp::canFoldIntoConsumerOp(castOp))
3156  return failure();
3157 
3158  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3159  // the MemRefCastOp source operand type to infer the result type and the
3160  // current SubViewOp source operand type to compute the dropped dimensions
3161  // if the operation is rank-reducing.
3162  auto resultType = getCanonicalSubViewResultType(
3163  subViewOp.getType(), subViewOp.getSourceType(),
3164  llvm::cast<MemRefType>(castOp.getSource().getType()),
3165  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3166  subViewOp.getMixedStrides());
3167  if (!resultType)
3168  return failure();
3169 
3170  Value newSubView = SubViewOp::create(
3171  rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
3172  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3173  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3174  subViewOp.getStaticStrides());
3175  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3176  newSubView);
3177  return success();
3178  }
3179 };
3180 
3181 /// Canonicalize subview ops that are no-ops. If the source type differs from
3182 /// the result type (e.g. due to use of `affine_map`), a cast is inserted.
3183 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3184 public:
3185  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3186 
3187  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3188  PatternRewriter &rewriter) const override {
3189  if (!isTrivialSubViewOp(subViewOp))
3190  return failure();
3191  if (subViewOp.getSourceType() == subViewOp.getType()) {
3192  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3193  return success();
3194  }
3195  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3196  subViewOp.getSource());
3197  return success();
3198  }
3199 };
3200 } // namespace
3201 
3202 /// Return the canonical type of the result of a subview.
3203 struct SubViewReturnTypeCanonicalizer {
3204  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3205  ArrayRef<OpFoldResult> mixedSizes,
3206  ArrayRef<OpFoldResult> mixedStrides) {
3207  // Infer a memref type without taking into account any rank reductions.
3208  MemRefType resTy = SubViewOp::inferResultType(
3209  op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3210  if (!resTy)
3211  return {};
3212  MemRefType nonReducedType = resTy;
3213 
3214  // Directly return the non-rank reduced type if there are no dropped dims.
3215  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3216  if (droppedDims.none())
3217  return nonReducedType;
3218 
3219  // Take the strides and offset from the non-rank reduced type.
3220  auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3221 
3222  // Drop dims from shape and strides.
3223  SmallVector<int64_t> targetShape;
3224  SmallVector<int64_t> targetStrides;
3225  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3226  if (droppedDims.test(i))
3227  continue;
3228  targetStrides.push_back(nonReducedStrides[i]);
3229  targetShape.push_back(nonReducedType.getDimSize(i));
3230  }
3231 
3232  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3233  StridedLayoutAttr::get(nonReducedType.getContext(),
3234  offset, targetStrides),
3235  nonReducedType.getMemorySpace());
3236  }
3237 };
3238 
3239 /// A canonicalizer wrapper to replace SubViewOps.
3240 struct SubViewCanonicalizer {
3241  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3242  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3243  }
3244 };
3245 
3246 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3247  MLIRContext *context) {
3248  results
3249  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3250  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3251  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3252 }
3253 
3254 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3255  MemRefType sourceMemrefType = getSource().getType();
3256  MemRefType resultMemrefType = getResult().getType();
3257  auto resultLayout =
3258  dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3259 
3260  if (resultMemrefType == sourceMemrefType &&
3261  resultMemrefType.hasStaticShape() &&
3262  (!resultLayout || resultLayout.hasStaticLayout())) {
3263  return getViewSource();
3264  }
3265 
3266  // Fold subview(subview(x)), where both subviews have the same size and the
3267  // second subview's offsets are all zero. (I.e., the second subview is a
3268  // no-op.)
3269  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3270  auto srcSizes = srcSubview.getMixedSizes();
3271  auto sizes = getMixedSizes();
3272  auto offsets = getMixedOffsets();
3273  bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3274  auto strides = getMixedStrides();
3275  bool allStridesOne = llvm::all_of(strides, isOneInteger);
3276  bool allSizesSame = llvm::equal(sizes, srcSizes);
3277  if (allOffsetsZero && allStridesOne && allSizesSame &&
3278  resultMemrefType == sourceMemrefType)
3279  return getViewSource();
3280  }
3281 
3282  return {};
3283 }
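A minimal sketch of the subview-of-subview fold described above (illustrative shapes, not from this file): the outer subview covers the inner one entirely, with zero offsets and unit strides, so fold() replaces it with its source:

```mlir
func.func @fold_subview_of_subview(%src: memref<16x16xf32>)
    -> memref<4x4xf32, strided<[16, 1], offset: 34>> {
  %0 = memref.subview %src[2, 2] [4, 4] [1, 1]
      : memref<16x16xf32> to memref<4x4xf32, strided<[16, 1], offset: 34>>
  // Same sizes as %0, zero offsets, unit strides, and identical source and
  // result types: %1 folds to %0.
  %1 = memref.subview %0[0, 0] [4, 4] [1, 1]
      : memref<4x4xf32, strided<[16, 1], offset: 34>>
        to memref<4x4xf32, strided<[16, 1], offset: 34>>
  return %1 : memref<4x4xf32, strided<[16, 1], offset: 34>>
}
```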
3284 
3285 //===----------------------------------------------------------------------===//
3286 // TransposeOp
3287 //===----------------------------------------------------------------------===//
3288 
3289 void TransposeOp::getAsmResultNames(
3290  function_ref<void(Value, StringRef)> setNameFn) {
3291  setNameFn(getResult(), "transpose");
3292 }
3293 
3294 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3295 static MemRefType inferTransposeResultType(MemRefType memRefType,
3296  AffineMap permutationMap) {
3297  auto originalSizes = memRefType.getShape();
3298  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3299  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3300 
3301  // Compute permuted sizes and strides.
3302  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3303  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3304 
3305  return MemRefType::Builder(memRefType)
3306  .setShape(sizes)
3307  .setLayout(
3308  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3309 }
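For example (hand-written, illustrative shapes), transposing a contiguous 4x8 memref permutes both the shape and the strides, which is what the helper above computes:

```mlir
func.func @transpose_2d(%arg: memref<4x8xf32>)
    -> memref<8x4xf32, strided<[1, 8]>> {
  // Original strides are [8, 1]; the permutation (d0, d1) -> (d1, d0)
  // yields shape [8, 4] and strides [1, 8].
  %0 = memref.transpose %arg (d0, d1) -> (d1, d0)
      : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
  return %0 : memref<8x4xf32, strided<[1, 8]>>
}
```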
3310 
3311 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3312  AffineMapAttr permutation,
3313  ArrayRef<NamedAttribute> attrs) {
3314  auto permutationMap = permutation.getValue();
3315  assert(permutationMap);
3316 
3317  auto memRefType = llvm::cast<MemRefType>(in.getType());
3318  // Compute result type.
3319  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3320 
3321  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3322  build(b, result, resultType, in, attrs);
3323 }
3324 
3325 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3326 void TransposeOp::print(OpAsmPrinter &p) {
3327  p << " " << getIn() << " " << getPermutation();
3328  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3329  p << " : " << getIn().getType() << " to " << getType();
3330 }
3331 
3332 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3333  OpAsmParser::UnresolvedOperand in;
3334  AffineMap permutation;
3335  MemRefType srcType, dstType;
3336  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3337  parser.parseOptionalAttrDict(result.attributes) ||
3338  parser.parseColonType(srcType) ||
3339  parser.resolveOperand(in, srcType, result.operands) ||
3340  parser.parseKeywordType("to", dstType) ||
3341  parser.addTypeToList(dstType, result.types))
3342  return failure();
3343 
3344  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3345  AffineMapAttr::get(permutation));
3346  return success();
3347 }
3348 
3349 LogicalResult TransposeOp::verify() {
3350  if (!getPermutation().isPermutation())
3351  return emitOpError("expected a permutation map");
3352  if (getPermutation().getNumDims() != getIn().getType().getRank())
3353  return emitOpError("expected a permutation map of same rank as the input");
3354 
3355  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3356  auto resultType = llvm::cast<MemRefType>(getType());
3357  auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3358  .canonicalizeStridedLayout();
3359 
3360  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3361  return emitOpError("result type ")
3362  << resultType
3363  << " is not equivalent to the canonical transposed input type "
3364  << canonicalResultType;
3365  return success();
3366 }
3367 
3368 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3369  // First check for identity permutation, we can fold it away if input and
3370  // result types are identical already.
3371  if (getPermutation().isIdentity() && getType() == getIn().getType())
3372  return getIn();
3373  // Fold two consecutive memref.transpose Ops into one by composing their
3374  // permutation maps.
3375  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3376  AffineMap composedPermutation =
3377  getPermutation().compose(otherTransposeOp.getPermutation());
3378  getInMutable().assign(otherTransposeOp.getIn());
3379  setPermutation(composedPermutation);
3380  return getResult();
3381  }
3382  return {};
3383 }
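A sketch of the second fold (illustrative shapes, not from this file): two back-to-back transposes collapse into a single memref.transpose of the original input with the composed permutation:

```mlir
func.func @compose_transposes(%arg: memref<2x3x4xf32>)
    -> memref<4x2x3xf32, strided<[1, 12, 4]>> {
  %0 = memref.transpose %arg (d0, d1, d2) -> (d1, d2, d0)
      : memref<2x3x4xf32> to memref<3x4x2xf32, strided<[4, 1, 12]>>
  // Folds into a single memref.transpose of %arg with the composed
  // permutation (d0, d1, d2) -> (d2, d0, d1).
  %1 = memref.transpose %0 (d0, d1, d2) -> (d1, d2, d0)
      : memref<3x4x2xf32, strided<[4, 1, 12]>>
        to memref<4x2x3xf32, strided<[1, 12, 4]>>
  return %1 : memref<4x2x3xf32, strided<[1, 12, 4]>>
}
```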
3384 
3385 //===----------------------------------------------------------------------===//
3386 // ViewOp
3387 //===----------------------------------------------------------------------===//
3388 
3389 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3390  setNameFn(getResult(), "view");
3391 }
3392 
3393 LogicalResult ViewOp::verify() {
3394  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3395  auto viewType = getType();
3396 
3397  // The base memref should have identity layout map (or none).
3398  if (!baseType.getLayout().isIdentity())
3399  return emitError("unsupported map for base memref type ") << baseType;
3400 
3401  // The result memref should have identity layout map (or none).
3402  if (!viewType.getLayout().isIdentity())
3403  return emitError("unsupported map for result memref type ") << viewType;
3404 
3405  // The base memref and the view memref should be in the same memory space.
3406  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3407  return emitError("different memory spaces specified for base memref "
3408  "type ")
3409  << baseType << " and view memref type " << viewType;
3410 
3411  // Verify that we have the correct number of sizes for the result type.
3412  unsigned numDynamicDims = viewType.getNumDynamicDims();
3413  if (getSizes().size() != numDynamicDims)
3414  return emitError("incorrect number of size operands for type ") << viewType;
3415 
3416  return success();
3417 }
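A hand-written example that satisfies these checks (illustrative names and shapes, not from this file): both the i8 source and the result use the identity layout, share the default memory space, and exactly one size operand is supplied for the single dynamic result dimension:

```mlir
func.func @view_example(%buf: memref<2048xi8>, %size: index) -> memref<8x?xf32> {
  %c0 = arith.constant 0 : index
  %0 = memref.view %buf[%c0][%size] : memref<2048xi8> to memref<8x?xf32>
  return %0 : memref<8x?xf32>
}
```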
3418 
3419 Value ViewOp::getViewSource() { return getSource(); }
3420 
3421 OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3422  MemRefType sourceMemrefType = getSource().getType();
3423  MemRefType resultMemrefType = getResult().getType();
3424 
3425  if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
3426  return getViewSource();
3427 
3428  return {};
3429 }
3430 
3431 namespace {
3432 
3433 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3434  using OpRewritePattern<ViewOp>::OpRewritePattern;
3435 
3436  LogicalResult matchAndRewrite(ViewOp viewOp,
3437  PatternRewriter &rewriter) const override {
3438  // Return if none of the operands are constants.
3439  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3440  return matchPattern(operand, matchConstantIndex());
3441  }))
3442  return failure();
3443 
3444  // Get result memref type.
3445  auto memrefType = viewOp.getType();
3446 
3447  // Get offset from old memref view type 'memRefType'.
3448  int64_t oldOffset;
3449  SmallVector<int64_t, 4> oldStrides;
3450  if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3451  return failure();
3452  assert(oldOffset == 0 && "Expected 0 offset");
3453 
3454  SmallVector<Value, 4> newOperands;
3455 
3456  // Offset cannot be folded into result type.
3457 
3458  // Fold any dynamic dim operands which are produced by a constant.
3459  SmallVector<int64_t, 4> newShapeConstants;
3460  newShapeConstants.reserve(memrefType.getRank());
3461 
3462  unsigned dynamicDimPos = 0;
3463  unsigned rank = memrefType.getRank();
3464  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3465  int64_t dimSize = memrefType.getDimSize(dim);
3466  // If this is already a static dimension, keep it.
3467  if (ShapedType::isStatic(dimSize)) {
3468  newShapeConstants.push_back(dimSize);
3469  continue;
3470  }
3471  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3472  if (auto constantIndexOp =
3473  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3474  // Dynamic shape dimension will be folded.
3475  newShapeConstants.push_back(constantIndexOp.value());
3476  } else {
3477  // Dynamic shape dimension not folded; copy operand from old memref.
3478  newShapeConstants.push_back(dimSize);
3479  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3480  }
3481  dynamicDimPos++;
3482  }
3483 
3484  // Create new memref type with constant folded dims.
3485  MemRefType newMemRefType =
3486  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3487  // Nothing new, don't fold.
3488  if (newMemRefType == memrefType)
3489  return failure();
3490 
3491  // Create new ViewOp.
3492  auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
3493  viewOp.getOperand(0), viewOp.getByteShift(),
3494  newOperands);
3495  // Insert a cast so we have the same type as the old memref type.
3496  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3497  return success();
3498  }
3499 };
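A sketch of the rewrite performed by ViewOpShapeFolder (illustrative names and shapes, not from this file): a size operand produced by an arith.constant is folded into the result type, and a memref.cast restores the original type for existing users:

```mlir
func.func @view_shape_folder(%buf: memref<2048xi8>) -> memref<?x4xf32> {
  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  // Rewritten to roughly:
  //   %v = memref.view %buf[%c0][] : memref<2048xi8> to memref<16x4xf32>
  //   %r = memref.cast %v : memref<16x4xf32> to memref<?x4xf32>
  %0 = memref.view %buf[%c0][%c16] : memref<2048xi8> to memref<?x4xf32>
  return %0 : memref<?x4xf32>
}
```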
3500 
3501 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3502  using OpRewritePattern<ViewOp>::OpRewritePattern;
3503 
3504  LogicalResult matchAndRewrite(ViewOp viewOp,
3505  PatternRewriter &rewriter) const override {
3506  Value memrefOperand = viewOp.getOperand(0);
3507  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3508  if (!memrefCastOp)
3509  return failure();
3510  Value allocOperand = memrefCastOp.getOperand();
3511  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3512  if (!allocOp)
3513  return failure();
3514  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3515  viewOp.getByteShift(),
3516  viewOp.getSizes());
3517  return success();
3518  }
3519 };
3520 
3521 } // namespace
3522 
3523 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3524  MLIRContext *context) {
3525  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3526 }
3527 
3528 //===----------------------------------------------------------------------===//
3529 // AtomicRMWOp
3530 //===----------------------------------------------------------------------===//
3531 
3532 LogicalResult AtomicRMWOp::verify() {
3533  if (getMemRefType().getRank() != getNumOperands() - 2)
3534  return emitOpError(
3535  "expects the number of subscripts to be equal to memref rank");
3536  switch (getKind()) {
3537  case arith::AtomicRMWKind::addf:
3538  case arith::AtomicRMWKind::maximumf:
3539  case arith::AtomicRMWKind::minimumf:
3540  case arith::AtomicRMWKind::mulf:
3541  if (!llvm::isa<FloatType>(getValue().getType()))
3542  return emitOpError() << "with kind '"
3543  << arith::stringifyAtomicRMWKind(getKind())
3544  << "' expects a floating-point type";
3545  break;
3546  case arith::AtomicRMWKind::addi:
3547  case arith::AtomicRMWKind::maxs:
3548  case arith::AtomicRMWKind::maxu:
3549  case arith::AtomicRMWKind::mins:
3550  case arith::AtomicRMWKind::minu:
3551  case arith::AtomicRMWKind::muli:
3552  case arith::AtomicRMWKind::ori:
3553  case arith::AtomicRMWKind::xori:
3554  case arith::AtomicRMWKind::andi:
3555  if (!llvm::isa<IntegerType>(getValue().getType()))
3556  return emitOpError() << "with kind '"
3557  << arith::stringifyAtomicRMWKind(getKind())
3558  << "' expects an integer type";
3559  break;
3560  default:
3561  break;
3562  }
3563  return success();
3564 }
3565 
3566 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3567  /// atomicrmw(memrefcast) -> atomicrmw
3568  if (succeeded(foldMemRefCast(*this, getValue())))
3569  return getResult();
3570  return OpFoldResult();
3571 }
3572 
3573 //===----------------------------------------------------------------------===//
3574 // TableGen'd op method definitions
3575 //===----------------------------------------------------------------------===//
3576 
3577 #define GET_OP_CLASSES
3578 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static bool hasSideEffects(Operation *op)
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition: IRAffine.cpp:67
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
Definition: MemRefOps.cpp:96
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
Definition: MemRefOps.cpp:1471
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
Definition: MemRefOps.cpp:2071
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
Definition: MemRefOps.cpp:377
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
Definition: MemRefOps.cpp:3295
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
Definition: MemRefOps.cpp:358
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
Definition: MemRefOps.cpp:2853
static LogicalResult FoldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
Definition: MemRefOps.cpp:761
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
Definition: MemRefOps.cpp:1308
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
Definition: MemRefOps.cpp:2886
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
Definition: MemRefOps.cpp:3024
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition: MemRefOps.cpp:1485
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
Definition: MemRefOps.cpp:848
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
Definition: MemRefOps.cpp:3088
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
Definition: MemRefOps.cpp:400
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
Definition: MemRefOps.cpp:2864
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurrences of it.
Definition: MemRefOps.cpp:833
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
Definition: MemRefOps.cpp:2164
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
Definition: MemRefOps.cpp:2359
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
Definition: MemRefOps.cpp:129
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition: TensorOps.cpp:140
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:552
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:104
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition: Block.cpp:250
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:166
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:111
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:99
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:50
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:182
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:203
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:193
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:552
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:519
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Definition: Operation.cpp:226
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Region.h:98
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static WalkResult skip()
Definition: WalkResult.h:48
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition: MemRefOps.cpp:60
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition: MemRefOps.cpp:68
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Definition: MemRefOps.cpp:3056
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:322
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:356
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
Definition: MemRefOps.cpp:2991
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:23
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Move allocations into an allocation scope, if it is legal to move them (e.g.
Definition: MemRefOps.cpp:448
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:451
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
Definition: MemRefOps.cpp:408
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:411
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:2527
A canonicalizer wrapper to replace SubViewOps.
Definition: MemRefOps.cpp:3240
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Definition: MemRefOps.cpp:3241
Return the canonical type of the result of a subview.
Definition: MemRefOps.cpp:3203
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: MemRefOps.cpp:3204
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
Result for slice bounds verification;.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.