MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Arith/Utils/Utils.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/Utils/StaticValueUtils.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferIntRangeInterface.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
23 #include "mlir/Interfaces/ViewLikeInterface.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::memref;
29 
30 /// Materialize a single constant operation from a given attribute value with
31 /// the desired resultant type.
32 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
33  Attribute value, Type type,
34  Location loc) {
35  return arith::ConstantOp::materialize(builder, value, type, loc);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Common canonicalization pattern support logic
40 //===----------------------------------------------------------------------===//
41 
42 /// This is a common class used for patterns of the form
43 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
44 /// into the root operation directly.
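/// An illustrative (hypothetical) example of the rewrite this enables:
///
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
///
/// can fold to:
///
///   memref.dealloc %arg : memref<8xf32>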
45 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
46  bool folded = false;
47  for (OpOperand &operand : op->getOpOperands()) {
48  auto cast = operand.get().getDefiningOp<CastOp>();
49  if (cast && operand.get() != inner &&
50  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
51  operand.set(cast.getOperand());
52  folded = true;
53  }
54  }
55  return success(folded);
56 }
57 
58 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
59 /// type.
60 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
61  if (auto memref = llvm::dyn_cast<MemRefType>(type))
62  return RankedTensorType::get(memref.getShape(), memref.getElementType());
63  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
64  return UnrankedTensorType::get(memref.getElementType());
65  return NoneType::get(type.getContext());
66 }
67 
68 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
69  int64_t dim) {
70  auto memrefType = llvm::cast<MemRefType>(value.getType());
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78  Location loc, Value value) {
79  auto memrefType = llvm::cast<MemRefType>(value.getType());
80  SmallVector<OpFoldResult> result;
81  for (int64_t i = 0; i < memrefType.getRank(); ++i)
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result;
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that sets values[i] to constValues[i] if the latter is a
91 /// static value, as indicated by ShapedType::kDynamic.
92 ///
93 /// If constValues[i] is dynamic, tries to extract a constant value from
94 /// value[i] to allow for additional folding opportunities. Also converts all
95 /// existing attributes to index attributes. (They may be i64 attributes.)
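///
/// For example (illustrative): with values = {%d, i64 attr 4} and
/// constValues = {8, ShapedType::kDynamic}, the result is
/// {index attr 8, index attr 4}: the first entry is taken from constValues,
/// the second is extracted from the existing attribute and converted to an
/// index attribute.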
96 static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
97  ArrayRef<int64_t> constValues) {
98  assert(constValues.size() == values.size() &&
99  "incorrect number of const values");
100  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
101  Builder builder(values[i].getContext());
102  if (ShapedType::isStatic(cstVal)) {
103  // Constant value is known, use it directly.
104  values[i] = builder.getIndexAttr(cstVal);
105  continue;
106  }
107  if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
108  // Try to extract a constant or convert an existing to index.
109  values[i] = builder.getIndexAttr(*cst);
110  }
111  }
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // AllocOp / AllocaOp
116 //===----------------------------------------------------------------------===//
117 
118 void AllocOp::getAsmResultNames(
119  function_ref<void(Value, StringRef)> setNameFn) {
120  setNameFn(getResult(), "alloc");
121 }
122 
123 void AllocaOp::getAsmResultNames(
124  function_ref<void(Value, StringRef)> setNameFn) {
125  setNameFn(getResult(), "alloca");
126 }
127 
128 template <typename AllocLikeOp>
129 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
130  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
131  "applies to only alloc or alloca");
132  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
133  if (!memRefType)
134  return op.emitOpError("result must be a memref");
135 
136  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
137  return op.emitOpError("dimension operand count does not equal memref "
138  "dynamic dimension count");
139 
140  unsigned numSymbols = 0;
141  if (!memRefType.getLayout().isIdentity())
142  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
143  if (op.getSymbolOperands().size() != numSymbols)
144  return op.emitOpError("symbol operand count does not equal memref symbol "
145  "count: expected ")
146  << numSymbols << ", got " << op.getSymbolOperands().size();
147 
148  return success();
149 }
150 
151 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
152 
153 LogicalResult AllocaOp::verify() {
154  // An alloca op needs to have an ancestor with an allocation scope trait.
155  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
156  return emitOpError(
157  "requires an ancestor op with AutomaticAllocationScope trait");
158 
159  return verifyAllocLikeOp(*this);
160 }
161 
162 namespace {
163 /// Fold constant dimensions into an alloc like operation.
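///
/// An illustrative (hypothetical) example:
///
///   %c8 = arith.constant 8 : index
///   %a = memref.alloc(%c8, %n) : memref<?x?xf32>
///
/// canonicalizes to an alloc with a more static type plus a cast back to the
/// original type:
///
///   %0 = memref.alloc(%n) : memref<8x?xf32>
///   %a = memref.cast %0 : memref<8x?xf32> to memref<?x?xf32>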
164 template <typename AllocLikeOp>
165 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
166  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
167 
168  LogicalResult matchAndRewrite(AllocLikeOp alloc,
169  PatternRewriter &rewriter) const override {
170  // Check to see if any dimension operands are constants. If so, we can
171  // substitute and drop them.
172  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
173  APInt constSizeArg;
174  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
175  return false;
176  return constSizeArg.isNonNegative();
177  }))
178  return failure();
179 
180  auto memrefType = alloc.getType();
181 
182  // Ok, we have one or more constant operands. Collect the non-constant ones
183  // and keep track of the resultant memref type to build.
184  SmallVector<int64_t, 4> newShapeConstants;
185  newShapeConstants.reserve(memrefType.getRank());
186  SmallVector<Value, 4> dynamicSizes;
187 
188  unsigned dynamicDimPos = 0;
189  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
190  int64_t dimSize = memrefType.getDimSize(dim);
191  // If this is already a static dimension, keep it.
192  if (ShapedType::isStatic(dimSize)) {
193  newShapeConstants.push_back(dimSize);
194  continue;
195  }
196  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
197  APInt constSizeArg;
198  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
199  constSizeArg.isNonNegative()) {
200  // Dynamic shape dimension will be folded.
201  newShapeConstants.push_back(constSizeArg.getZExtValue());
202  } else {
203  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
204  newShapeConstants.push_back(ShapedType::kDynamic);
205  dynamicSizes.push_back(dynamicSize);
206  }
207  dynamicDimPos++;
208  }
209 
210  // Create new memref type (which will have fewer dynamic dimensions).
211  MemRefType newMemRefType =
212  MemRefType::Builder(memrefType).setShape(newShapeConstants);
213  assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
214 
215  // Create and insert the alloc op for the new memref.
216  auto newAlloc = rewriter.create<AllocLikeOp>(
217  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
218  alloc.getAlignmentAttr());
219  // Insert a cast so we have the same type as the old alloc.
220  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
221  return success();
222  }
223 };
224 
225 /// Fold alloc operations with no users or only store and dealloc uses.
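///
/// An illustrative (hypothetical) example: the alloc below is only written to
/// and deallocated, so the pattern erases all three operations:
///
///   %a = memref.alloc() : memref<4xf32>
///   memref.store %v, %a[%i] : memref<4xf32>
///   memref.dealloc %a : memref<4xf32>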
226 template <typename T>
227 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
228  using OpRewritePattern<T>::OpRewritePattern;
229 
230  LogicalResult matchAndRewrite(T alloc,
231  PatternRewriter &rewriter) const override {
232  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
233  if (auto storeOp = dyn_cast<StoreOp>(op))
234  return storeOp.getValue() == alloc;
235  return !isa<DeallocOp>(op);
236  }))
237  return failure();
238 
239  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
240  rewriter.eraseOp(user);
241 
242  rewriter.eraseOp(alloc);
243  return success();
244  }
245 };
246 } // namespace
247 
248 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
249  MLIRContext *context) {
250  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
251 }
252 
253 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
254  MLIRContext *context) {
255  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
256  context);
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // ReallocOp
261 //===----------------------------------------------------------------------===//
262 
263 LogicalResult ReallocOp::verify() {
264  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
265  MemRefType resultType = getType();
266 
267  // The source memref should have identity layout (or none).
268  if (!sourceType.getLayout().isIdentity())
269  return emitError("unsupported layout for source memref type ")
270  << sourceType;
271 
272  // The result memref should have identity layout (or none).
273  if (!resultType.getLayout().isIdentity())
274  return emitError("unsupported layout for result memref type ")
275  << resultType;
276 
277  // The source memref and the result memref should be in the same memory space.
278  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
279  return emitError("different memory spaces specified for source memref "
280  "type ")
281  << sourceType << " and result memref type " << resultType;
282 
283  // The source memref and the result memref should have the same element type.
284  if (sourceType.getElementType() != resultType.getElementType())
285  return emitError("different element types specified for source memref "
286  "type ")
287  << sourceType << " and result memref type " << resultType;
288 
289  // Verify that we have the dynamic dimension operand when it is needed.
290  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
291  return emitError("missing dimension operand for result type ")
292  << resultType;
293  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
294  return emitError("unnecessary dimension operand for result type ")
295  << resultType;
296 
297  return success();
298 }
299 
300 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
301  MLIRContext *context) {
302  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // AllocaScopeOp
307 //===----------------------------------------------------------------------===//
308 
309 void AllocaScopeOp::print(OpAsmPrinter &p) {
310  bool printBlockTerminators = false;
311 
312  p << ' ';
313  if (!getResults().empty()) {
314  p << " -> (" << getResultTypes() << ")";
315  printBlockTerminators = true;
316  }
317  p << ' ';
318  p.printRegion(getBodyRegion(),
319  /*printEntryBlockArgs=*/false,
320  /*printBlockTerminators=*/printBlockTerminators);
321  p.printOptionalAttrDict((*this)->getAttrs());
322 }
323 
324 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
325  // Create a region for the body.
326  result.regions.reserve(1);
327  Region *bodyRegion = result.addRegion();
328 
329  // Parse optional results type list.
330  if (parser.parseOptionalArrowTypeList(result.types))
331  return failure();
332 
333  // Parse the body region.
334  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
335  return failure();
336  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
337  result.location);
338 
339  // Parse the optional attribute list.
340  if (parser.parseOptionalAttrDict(result.attributes))
341  return failure();
342 
343  return success();
344 }
345 
346 void AllocaScopeOp::getSuccessorRegions(
347  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
348  if (!point.isParent()) {
349  regions.push_back(RegionSuccessor(getResults()));
350  return;
351  }
352 
353  regions.push_back(RegionSuccessor(&getBodyRegion()));
354 }
355 
356 /// Given an operation, return whether this op is guaranteed to
357 /// allocate an AutomaticAllocationScopeResource
358 static bool isGuaranteedAutomaticAllocation(Operation *op) {
359  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
360  if (!interface)
361  return false;
362  for (auto res : op->getResults()) {
363  if (auto effect =
364  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
365  if (isa<SideEffects::AutomaticAllocationScopeResource>(
366  effect->getResource()))
367  return true;
368  }
369  }
370  return false;
371 }
372 
373 /// Given an operation, return whether this op itself could
374 /// allocate an AutomaticAllocationScopeResource. Note that
375 /// this will not check whether an operation contained within
376 /// the op can allocate.
377 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
378  // This op itself doesn't create a stack allocation,
379  // the inner allocation should be handled separately.
380  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
381  return false;
382  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
383  if (!interface)
384  return true;
385  for (auto res : op->getResults()) {
386  if (auto effect =
387  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
388  if (isa<SideEffects::AutomaticAllocationScopeResource>(
389  effect->getResource()))
390  return true;
391  }
392  }
393  return false;
394 }
395 
396 /// Return whether this op is the last non terminating op
397 /// in a region. That is to say, it is in a one-block region
398 /// and is only followed by a terminator. This prevents
399 /// extending the lifetime of allocations.
400 static bool lastNonTerminatorInRegion(Operation *op) {
401  return op->getBlock()->mightHaveTerminator() &&
402  op->getNextNode() == op->getBlock()->getTerminator() &&
403  op->getParentRegion()->hasOneBlock();
404 }
405 
406 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
407 /// or it contains no allocation.
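///
/// An illustrative (hypothetical) example: a scope that performs no stack
/// allocation, such as
///
///   %r = memref.alloca_scope -> (f32) {
///     %v = arith.addf %a, %b : f32
///     memref.alloca_scope.return %v : f32
///   }
///
/// can be inlined into the parent block, with %r replaced by %v.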
408 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
409  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
410 
411  LogicalResult matchAndRewrite(AllocaScopeOp op,
412  PatternRewriter &rewriter) const override {
413  bool hasPotentialAlloca =
414  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
415  if (alloc == op)
416  return WalkResult::advance();
417  if (isOpItselfPotentialAutomaticAllocation(alloc))
418  return WalkResult::interrupt();
419  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
420  return WalkResult::skip();
421  return WalkResult::advance();
422  }).wasInterrupted();
423 
424  // If this contains no potential allocation, it is always legal to
425  // inline. Otherwise, consider two conditions:
426  if (hasPotentialAlloca) {
427  // If the parent isn't an allocation scope, or we are not the last
428  // non-terminator op in the parent, we will extend the lifetime.
429  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
430  return failure();
431  if (!lastNonTerminatorInRegion(op))
432  return failure();
433  }
434 
435  Block *block = &op.getRegion().front();
436  Operation *terminator = block->getTerminator();
437  ValueRange results = terminator->getOperands();
438  rewriter.inlineBlockBefore(block, op);
439  rewriter.replaceOp(op, results);
440  rewriter.eraseOp(terminator);
441  return success();
442  }
443 };
444 
445 /// Move allocations into an allocation scope, if it is legal to
446 /// move them (e.g. their operands are available at the location
447 /// the op would be moved to).
448 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
449  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
450 
451  LogicalResult matchAndRewrite(AllocaScopeOp op,
452  PatternRewriter &rewriter) const override {
453 
454  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
455  return failure();
456 
457  Operation *lastParentWithoutScope = op->getParentOp();
458 
459  if (!lastParentWithoutScope ||
460  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
461  return failure();
462 
463  // Only apply if this is the last non-terminator op in the block
464  // of a one-block region (lest the lifetime of the allocations
465  // be extended).
466  if (!lastNonTerminatorInRegion(op) ||
467  !lastNonTerminatorInRegion(lastParentWithoutScope))
468  return failure();
469 
470  while (!lastParentWithoutScope->getParentOp()
471  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
472  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
473  if (!lastParentWithoutScope ||
474  !lastNonTerminatorInRegion(lastParentWithoutScope))
475  return failure();
476  }
477  assert(lastParentWithoutScope->getParentOp()
478  ->hasTrait<OpTrait::AutomaticAllocationScope>());
479 
480  Region *containingRegion = nullptr;
481  for (auto &r : lastParentWithoutScope->getRegions()) {
482  if (r.isAncestor(op->getParentRegion())) {
483  assert(containingRegion == nullptr &&
484  "only one region can contain the op");
485  containingRegion = &r;
486  }
487  }
488  assert(containingRegion && "op must be contained in a region");
489 
490  SmallVector<Operation *> toHoist;
491  op->walk([&](Operation *alloc) {
492  if (!isGuaranteedAutomaticAllocation(alloc))
493  return WalkResult::skip();
494 
495  // If any operand is not defined before the location of
496  // lastParentWithoutScope (i.e. where we would hoist to), skip.
497  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
498  return containingRegion->isAncestor(v.getParentRegion());
499  }))
500  return WalkResult::skip();
501  toHoist.push_back(alloc);
502  return WalkResult::advance();
503  });
504 
505  if (toHoist.empty())
506  return failure();
507  rewriter.setInsertionPoint(lastParentWithoutScope);
508  for (auto *op : toHoist) {
509  auto *cloned = rewriter.clone(*op);
510  rewriter.replaceOp(op, cloned->getResults());
511  }
512  return success();
513  }
514 };
515 
516 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
517  MLIRContext *context) {
518  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // AssumeAlignmentOp
523 //===----------------------------------------------------------------------===//
524 
525 LogicalResult AssumeAlignmentOp::verify() {
526  if (!llvm::isPowerOf2_32(getAlignment()))
527  return emitOpError("alignment must be power of 2");
528  return success();
529 }
530 
531 void AssumeAlignmentOp::getAsmResultNames(
532  function_ref<void(Value, StringRef)> setNameFn) {
533  setNameFn(getResult(), "assume_align");
534 }
535 
536 OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
537  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
538  if (!source)
539  return {};
540  if (source.getAlignment() != getAlignment())
541  return {};
542  return getMemref();
543 }
544 
545 //===----------------------------------------------------------------------===//
546 // CastOp
547 //===----------------------------------------------------------------------===//
548 
549 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
550  setNameFn(getResult(), "cast");
551 }
552 
553 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
554 /// source memref. This is useful to fold a memref.cast into a consuming op
555 /// and implement canonicalization patterns for ops in different dialects that
556 /// may consume the results of memref.cast operations. Such foldable memref.cast
557 /// operations are typically inserted as `view` and `subview` ops are
558 /// canonicalized, to preserve the type compatibility of their uses.
559 ///
560 /// Returns true when all conditions are met:
561 /// 1. source and result are ranked memrefs with strided semantics and same
562 /// element type and rank.
563 /// 2. each of the source's size, offset or stride has more static information
564 /// than the corresponding result's size, offset or stride.
565 ///
566 /// Example 1:
567 /// ```mlir
568 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
569 /// %2 = consumer %1 ... : memref<?x?xf32> ...
570 /// ```
571 ///
572 /// may fold into:
573 ///
574 /// ```mlir
575 /// %2 = consumer %0 ... : memref<8x16xf32> ...
576 /// ```
577 ///
578 /// Example 2:
579 /// ```
580 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
581 /// to memref<?x?xf32>
582 /// consumer %1 : memref<?x?xf32> ...
583 /// ```
584 ///
585 /// may fold into:
586 ///
587 /// ```
588 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
589 /// ```
590 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
591  MemRefType sourceType =
592  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
593  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
594 
595  // Requires ranked MemRefType.
596  if (!sourceType || !resultType)
597  return false;
598 
599  // Requires same elemental type.
600  if (sourceType.getElementType() != resultType.getElementType())
601  return false;
602 
603  // Requires same rank.
604  if (sourceType.getRank() != resultType.getRank())
605  return false;
606 
607  // Only fold casts between strided memref forms.
608  int64_t sourceOffset, resultOffset;
609  SmallVector<int64_t, 4> sourceStrides, resultStrides;
610  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
611  failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
612  return false;
613 
614  // If cast is towards more static sizes along any dimension, don't fold.
615  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
616  auto ss = std::get<0>(it), st = std::get<1>(it);
617  if (ss != st)
618  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
619  return false;
620  }
621 
622  // If cast is towards more static offset along any dimension, don't fold.
623  if (sourceOffset != resultOffset)
624  if (ShapedType::isDynamic(sourceOffset) &&
625  ShapedType::isStatic(resultOffset))
626  return false;
627 
628  // If cast is towards more static strides along any dimension, don't fold.
629  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
630  auto ss = std::get<0>(it), st = std::get<1>(it);
631  if (ss != st)
632  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
633  return false;
634  }
635 
636  return true;
637 }
638 
639 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
640  if (inputs.size() != 1 || outputs.size() != 1)
641  return false;
642  Type a = inputs.front(), b = outputs.front();
643  auto aT = llvm::dyn_cast<MemRefType>(a);
644  auto bT = llvm::dyn_cast<MemRefType>(b);
645 
646  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
647  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
648 
649  if (aT && bT) {
650  if (aT.getElementType() != bT.getElementType())
651  return false;
652  if (aT.getLayout() != bT.getLayout()) {
653  int64_t aOffset, bOffset;
654  SmallVector<int64_t, 4> aStrides, bStrides;
655  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
656  failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
657  aStrides.size() != bStrides.size())
658  return false;
659 
660  // Strides along a dimension/offset are compatible if the value in the
661  // source memref is static and the value in the target memref is the
662  // same. They are also compatible if either one is dynamic (see
663  // description of MemRefCastOp for details).
664  auto checkCompatible = [](int64_t a, int64_t b) {
665  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
666  };
667  if (!checkCompatible(aOffset, bOffset))
668  return false;
669  for (const auto &aStride : enumerate(aStrides))
670  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
671  return false;
672  }
673  if (aT.getMemorySpace() != bT.getMemorySpace())
674  return false;
675 
676  // They must have the same rank, and any specified dimensions must match.
677  if (aT.getRank() != bT.getRank())
678  return false;
679 
680  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
681  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
682  if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
683  aDim != bDim)
684  return false;
685  }
686  return true;
687  } else {
688  if (!aT && !uaT)
689  return false;
690  if (!bT && !ubT)
691  return false;
692  // Unranked to unranked casting is unsupported
693  if (uaT && ubT)
694  return false;
695 
696  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
697  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
698  if (aEltType != bEltType)
699  return false;
700 
701  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
702  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
703  return aMemSpace == bMemSpace;
704  }
705 
706  return false;
707 }
708 
709 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
710  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // CopyOp
715 //===----------------------------------------------------------------------===//
716 
717 namespace {
718 
719 /// Fold memref.copy(%x, %x).
720 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
721  using OpRewritePattern<CopyOp>::OpRewritePattern;
722 
723  LogicalResult matchAndRewrite(CopyOp copyOp,
724  PatternRewriter &rewriter) const override {
725  if (copyOp.getSource() != copyOp.getTarget())
726  return failure();
727 
728  rewriter.eraseOp(copyOp);
729  return success();
730  }
731 };
732 
733 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
734  using OpRewritePattern<CopyOp>::OpRewritePattern;
735 
736  static bool isEmptyMemRef(BaseMemRefType type) {
737  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
738  }
739 
740  LogicalResult matchAndRewrite(CopyOp copyOp,
741  PatternRewriter &rewriter) const override {
742  if (isEmptyMemRef(copyOp.getSource().getType()) ||
743  isEmptyMemRef(copyOp.getTarget().getType())) {
744  rewriter.eraseOp(copyOp);
745  return success();
746  }
747 
748  return failure();
749  }
750 };
751 } // namespace
752 
753 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
754  MLIRContext *context) {
755  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
756 }
757 
758 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
759 /// and element type, the cast can be skipped. Such CastOps only cast the layout
760 /// of the type.
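///
/// An illustrative (hypothetical) example of a layout-only cast being skipped:
///
///   %0 = memref.cast %src
///       : memref<8xf32> to memref<8xf32, strided<[1], offset: ?>>
///   memref.copy %0, %dst : memref<8xf32, strided<[1], offset: ?>> to memref<8xf32>
///
/// folds to:
///
///   memref.copy %src, %dst : memref<8xf32> to memref<8xf32>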
761 static LogicalResult FoldCopyOfCast(CopyOp op) {
762  for (OpOperand &operand : op->getOpOperands()) {
763  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
764  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
765  operand.set(castOp.getOperand());
766  return success();
767  }
768  }
769  return failure();
770 }
771 
772 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
773  SmallVectorImpl<OpFoldResult> &results) {
774 
775  /// copy(memrefcast) -> copy
776  return FoldCopyOfCast(*this);
777 }
778 
779 //===----------------------------------------------------------------------===//
780 // DeallocOp
781 //===----------------------------------------------------------------------===//
782 
783 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
784  SmallVectorImpl<OpFoldResult> &results) {
785  /// dealloc(memrefcast) -> dealloc
786  return foldMemRefCast(*this);
787 }
788 
789 //===----------------------------------------------------------------------===//
790 // DimOp
791 //===----------------------------------------------------------------------===//
792 
793 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
794  setNameFn(getResult(), "dim");
795 }
796 
797 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
798  int64_t index) {
799  auto loc = result.location;
800  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
801  build(builder, result, source, indexValue);
802 }
803 
804 std::optional<int64_t> DimOp::getConstantIndex() {
805  return getConstantIntValue(getIndex());
806 }
807 
808 Speculation::Speculatability DimOp::getSpeculatability() {
809  auto constantIndex = getConstantIndex();
810  if (!constantIndex)
811  return Speculation::NotSpeculatable;
812 
813  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
814  if (!rankedSourceType)
815  return Speculation::NotSpeculatable;
816 
817  if (rankedSourceType.getRank() <= constantIndex)
818  return Speculation::NotSpeculatable;
819 
820  return Speculation::Speculatable;
821 }
822 
823 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
824  SetIntLatticeFn setResultRange) {
825  setResultRange(getResult(),
826  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
827 }
828 
829 /// Return a map from each element in `vals` to its number of occurrences.
830 /// Use std::map, since the `vals` here are strides and the
831 /// dynamic stride value is the same as the tombstone value for
832 /// `DenseMap<int64_t>`.
833 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
834  std::map<int64_t, unsigned> numOccurences;
835  for (auto val : vals)
836  numOccurences[val]++;
837  return numOccurences;
838 }
839 
840 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
841 /// to be a subset of `originalType` with some `1` entries erased, return the
842 /// set of indices that specifies which of the entries of `originalShape` are
843 /// dropped to obtain `reducedShape`.
844 /// This accounts for cases where there are multiple unit-dims, but only a
845 /// subset of those are dropped. For MemRefTypes these can be disambiguated
846 /// using the strides. If a dimension is dropped the stride must be dropped too.
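///
/// For example (illustrative): when reducing memref<1x4x1xf32> to
/// memref<4x1xf32> with sizes [1, 4, 1], either unit dimension could have been
/// dropped based on the shape alone; comparing the strides of the two types
/// disambiguates which one actually was.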
847 static FailureOr<llvm::SmallBitVector>
848 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
849  ArrayRef<OpFoldResult> sizes) {
850  llvm::SmallBitVector unusedDims(originalType.getRank());
851  if (originalType.getRank() == reducedType.getRank())
852  return unusedDims;
853 
854  for (const auto &dim : llvm::enumerate(sizes))
855  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
856  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
857  unusedDims.set(dim.index());
858 
859  // Early exit for the case where the number of unused dims matches the number
860  // of ranks reduced.
861  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
862  originalType.getRank())
863  return unusedDims;
864 
865  SmallVector<int64_t> originalStrides, candidateStrides;
866  int64_t originalOffset, candidateOffset;
867  if (failed(
868  originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
869  failed(
870  reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
871  return failure();
872 
873  // For memrefs, a dimension is truly dropped if its corresponding stride is
874  // also dropped. This is particularly important when more than one of the dims
875  // is 1. Track the number of occurrences of the strides in the original type
876  // and the candidate type. For each unused dim, that stride should not be
877  // present in the candidate type. Note that there could be multiple dimensions
878  // that have the same size. We don't need to figure out exactly which dim
879  // corresponds to which stride; we just need to verify that the number of
880  // repetitions of a stride in the original + number of unused dims with that
881  // stride == number of repetitions of that stride in the candidate.
882  std::map<int64_t, unsigned> currUnaccountedStrides =
883  getNumOccurences(originalStrides);
884  std::map<int64_t, unsigned> candidateStridesNumOccurences =
885  getNumOccurences(candidateStrides);
886  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
887  if (!unusedDims.test(dim))
888  continue;
889  int64_t originalStride = originalStrides[dim];
890  if (currUnaccountedStrides[originalStride] >
891  candidateStridesNumOccurences[originalStride]) {
892  // This dim can be treated as dropped.
893  currUnaccountedStrides[originalStride]--;
894  continue;
895  }
896  if (currUnaccountedStrides[originalStride] ==
897  candidateStridesNumOccurences[originalStride]) {
898  // The stride for this is not dropped. Keep as is.
899  unusedDims.reset(dim);
900  continue;
901  }
902  if (currUnaccountedStrides[originalStride] <
903  candidateStridesNumOccurences[originalStride]) {
904  // This should never happen. We can't have a stride in the reduced-rank type
905  // that wasn't in the original one.
906  return failure();
907  }
908  }
909 
910  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
911  originalType.getRank())
912  return failure();
913  return unusedDims;
914 }
915 
916 llvm::SmallBitVector SubViewOp::getDroppedDims() {
917  MemRefType sourceType = getSourceType();
918  MemRefType resultType = getType();
919  FailureOr<llvm::SmallBitVector> unusedDims =
920  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
921  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
922  return *unusedDims;
923 }
924 
925 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
926  // All forms of folding require a known index.
927  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
928  if (!index)
929  return {};
930 
931  // Folding for unranked types (UnrankedMemRefType) is not supported.
932  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
933  if (!memrefType)
934  return {};
935 
936  // Out of bound indices produce undefined behavior but are still valid IR.
937  // Don't choke on them.
938  int64_t indexVal = index.getInt();
939  if (indexVal < 0 || indexVal >= memrefType.getRank())
940  return {};
941 
942  // Fold if the shape extent along the given index is known.
943  if (!memrefType.isDynamicDim(index.getInt())) {
944  Builder builder(getContext());
945  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
946  }
947 
948  // The size at the given index is now known to be a dynamic size.
949  unsigned unsignedIndex = index.getValue().getZExtValue();
950 
951  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
952  Operation *definingOp = getSource().getDefiningOp();
953 
954  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
955  return *(alloc.getDynamicSizes().begin() +
956  memrefType.getDynamicDimIndex(unsignedIndex));
957 
958  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
959  return *(alloca.getDynamicSizes().begin() +
960  memrefType.getDynamicDimIndex(unsignedIndex));
961 
962  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
963  return *(view.getDynamicSizes().begin() +
964  memrefType.getDynamicDimIndex(unsignedIndex));
965 
966  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
967  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
968  unsigned resultIndex = 0;
969  unsigned sourceRank = subview.getSourceType().getRank();
970  unsigned sourceIndex = 0;
971  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
972  if (unusedDims.test(i))
973  continue;
974  if (resultIndex == unsignedIndex) {
975  sourceIndex = i;
976  break;
977  }
978  resultIndex++;
979  }
980  assert(subview.isDynamicSize(sourceIndex) &&
981  "expected dynamic subview size");
982  return subview.getDynamicSize(sourceIndex);
983  }
984 
985  if (auto sizeInterface =
986  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
987  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
988  "Expected dynamic subview size");
989  return sizeInterface.getDynamicSize(unsignedIndex);
990  }
991 
992  // dim(memrefcast) -> dim
993  if (succeeded(foldMemRefCast(*this)))
994  return getResult();
995 
996  return {};
997 }
998 
999 namespace {
1000 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1001 /// operand.
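///
/// An illustrative (hypothetical) example:
///
///   %r = memref.reshape %src(%shape)
///       : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %c1 : memref<?x?xf32>
///
/// becomes a load of the requested extent from the shape operand:
///
///   %d = memref.load %shape[%c1] : memref<2xindex>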
1002 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1003  using OpRewritePattern<DimOp>::OpRewritePattern;
1004 
1005  LogicalResult matchAndRewrite(DimOp dim,
1006  PatternRewriter &rewriter) const override {
1007  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1008 
1009  if (!reshape)
1010  return rewriter.notifyMatchFailure(
1011  dim, "Dim op is not defined by a reshape op.");
1012 
1013  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1014  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1015  // cheaply check that either of the following conditions hold:
1016  // 1. dim.getIndex() is defined in the same block as reshape but before
1017  // reshape.
1018  // 2. dim.getIndex() is defined in a parent block of
1019  // reshape.
1020 
1021  // Check condition 1
1022  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1023  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1024  if (reshape->isBeforeInBlock(definingOp)) {
1025  return rewriter.notifyMatchFailure(
1026  dim,
1027  "dim.getIndex is not defined before reshape in the same block.");
1028  }
1029  } // else dim.getIndex is a block argument to reshape->getBlock and
1030  // dominates reshape
1031  } // Check condition 2
1032  else if (dim->getBlock() != reshape->getBlock() &&
1033  !dim.getIndex().getParentRegion()->isProperAncestor(
1034  reshape->getParentRegion())) {
1035  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1036  // already know dim.getIndex() dominates reshape without calling
1037  // `isProperAncestor`
1038  return rewriter.notifyMatchFailure(
1039  dim, "dim.getIndex does not dominate reshape.");
1040  }
1041 
1042  // Place the load directly after the reshape to ensure that the shape memref
1043  // was not mutated.
1044  rewriter.setInsertionPointAfter(reshape);
1045  Location loc = dim.getLoc();
1046  Value load =
1047  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1048  if (load.getType() != dim.getType())
1049  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1050  rewriter.replaceOp(dim, load);
1051  return success();
1052  }
1053 };
1054 
1055 } // namespace
1056 
1057 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1058  MLIRContext *context) {
1059  results.add<DimOfMemRefReshape>(context);
1060 }
1061 
1062 // ---------------------------------------------------------------------------
1063 // DmaStartOp
1064 // ---------------------------------------------------------------------------
1065 
1066 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1067  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1068  ValueRange destIndices, Value numElements,
1069  Value tagMemRef, ValueRange tagIndices, Value stride,
1070  Value elementsPerStride) {
1071  result.addOperands(srcMemRef);
1072  result.addOperands(srcIndices);
1073  result.addOperands(destMemRef);
1074  result.addOperands(destIndices);
1075  result.addOperands({numElements, tagMemRef});
1076  result.addOperands(tagIndices);
1077  if (stride)
1078  result.addOperands({stride, elementsPerStride});
1079 }
1080 
1081 void DmaStartOp::print(OpAsmPrinter &p) {
1082  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1083  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1084  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1085  if (isStrided())
1086  p << ", " << getStride() << ", " << getNumElementsPerStride();
1087 
1088  p.printOptionalAttrDict((*this)->getAttrs());
1089  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1090  << ", " << getTagMemRef().getType();
1091 }
1092 
1093 // Parse DmaStartOp.
1094 // Ex:
1095 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1096 // %tag[%index], %stride, %num_elt_per_stride :
1097 // : memref<3076 x f32, 0>,
1098 // memref<1024 x f32, 2>,
1099 // memref<1 x i32>
1100 //
1101 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1102  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1103  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1104  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1105  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1106  OpAsmParser::UnresolvedOperand numElementsInfo;
1107  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1108  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1109  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1110 
1111  SmallVector<Type, 3> types;
1112  auto indexType = parser.getBuilder().getIndexType();
1113 
1114  // Parse and resolve the following list of operands:
1115  // *) source memref followed by its indices (in square brackets).
1116  // *) destination memref followed by its indices (in square brackets).
1117  // *) dma size in KiB.
1118  if (parser.parseOperand(srcMemRefInfo) ||
1119  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1120  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1121  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1122  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1123  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1124  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1125  return failure();
1126 
1127  // Parse optional stride and elements per stride.
1128  if (parser.parseTrailingOperandList(strideInfo))
1129  return failure();
1130 
1131  bool isStrided = strideInfo.size() == 2;
1132  if (!strideInfo.empty() && !isStrided) {
1133  return parser.emitError(parser.getNameLoc(),
1134  "expected two stride related operands");
1135  }
1136 
1137  if (parser.parseColonTypeList(types))
1138  return failure();
1139  if (types.size() != 3)
1140  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1141 
1142  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1143  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1144  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1145  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1146  // size should be an index.
1147  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1148  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1149  // tag indices should be index.
1150  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1151  return failure();
1152 
1153  if (isStrided) {
1154  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1155  return failure();
1156  }
1157 
1158  return success();
1159 }
1160 
1161 LogicalResult DmaStartOp::verify() {
1162  unsigned numOperands = getNumOperands();
1163 
1164  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1165  // the number of elements.
1166  if (numOperands < 4)
1167  return emitOpError("expected at least 4 operands");
1168 
1169  // Check types of operands. The order of these calls is important: the later
1170  // calls rely on some type properties to compute the operand position.
1171  // 1. Source memref.
1172  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1173  return emitOpError("expected source to be of memref type");
1174  if (numOperands < getSrcMemRefRank() + 4)
1175  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1176  << " operands";
1177  if (!getSrcIndices().empty() &&
1178  !llvm::all_of(getSrcIndices().getTypes(),
1179  [](Type t) { return t.isIndex(); }))
1180  return emitOpError("expected source indices to be of index type");
1181 
1182  // 2. Destination memref.
1183  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1184  return emitOpError("expected destination to be of memref type");
1185  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1186  if (numOperands < numExpectedOperands)
1187  return emitOpError() << "expected at least " << numExpectedOperands
1188  << " operands";
1189  if (!getDstIndices().empty() &&
1190  !llvm::all_of(getDstIndices().getTypes(),
1191  [](Type t) { return t.isIndex(); }))
1192  return emitOpError("expected destination indices to be of index type");
1193 
1194  // 3. Number of elements.
1195  if (!getNumElements().getType().isIndex())
1196  return emitOpError("expected num elements to be of index type");
1197 
1198  // 4. Tag memref.
1199  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1200  return emitOpError("expected tag to be of memref type");
1201  numExpectedOperands += getTagMemRefRank();
1202  if (numOperands < numExpectedOperands)
1203  return emitOpError() << "expected at least " << numExpectedOperands
1204  << " operands";
1205  if (!getTagIndices().empty() &&
1206  !llvm::all_of(getTagIndices().getTypes(),
1207  [](Type t) { return t.isIndex(); }))
1208  return emitOpError("expected tag indices to be of index type");
1209 
1210  // Optional stride-related operands must be either both present or both
1211  // absent.
1212  if (numOperands != numExpectedOperands &&
1213  numOperands != numExpectedOperands + 2)
1214  return emitOpError("incorrect number of operands");
1215 
1216  // 5. Strides.
1217  if (isStrided()) {
1218  if (!getStride().getType().isIndex() ||
1219  !getNumElementsPerStride().getType().isIndex())
1220  return emitOpError(
1221  "expected stride and num elements per stride to be of type index");
1222  }
1223 
1224  return success();
1225 }
1226 
1227 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1228  SmallVectorImpl<OpFoldResult> &results) {
1229  /// dma_start(memrefcast) -> dma_start
1230  return foldMemRefCast(*this);
1231 }
1232 
1233 // ---------------------------------------------------------------------------
1234 // DmaWaitOp
1235 // ---------------------------------------------------------------------------
1236 
1237 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1238  SmallVectorImpl<OpFoldResult> &results) {
1239  /// dma_wait(memrefcast) -> dma_wait
1240  return foldMemRefCast(*this);
1241 }
1242 
1243 LogicalResult DmaWaitOp::verify() {
1244  // Check that the number of tag indices matches the tagMemRef rank.
1245  unsigned numTagIndices = getTagIndices().size();
1246  unsigned tagMemRefRank = getTagMemRefRank();
1247  if (numTagIndices != tagMemRefRank)
1248  return emitOpError() << "expected tagIndices to have the same number of "
1249  "elements as the tagMemRef rank, expected "
1250  << tagMemRefRank << ", but got " << numTagIndices;
1251  return success();
1252 }
1253 
1254 //===----------------------------------------------------------------------===//
1255 // ExtractAlignedPointerAsIndexOp
1256 //===----------------------------------------------------------------------===//
1257 
1258 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1259  function_ref<void(Value, StringRef)> setNameFn) {
1260  setNameFn(getResult(), "intptr");
1261 }
1262 
1263 //===----------------------------------------------------------------------===//
1264 // ExtractStridedMetadataOp
1265 //===----------------------------------------------------------------------===//
1266 
1267 /// The number and type of the results are inferred from the
1268 /// shape of the source.
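///
/// For example (illustrative): for a memref<4x?xf32> source the results are a
/// rank-0 base buffer memref<f32>, one index offset, two index sizes, and two
/// index strides.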
1269 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1270  MLIRContext *context, std::optional<Location> location,
1271  ExtractStridedMetadataOp::Adaptor adaptor,
1272  SmallVectorImpl<Type> &inferredReturnTypes) {
1273  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1274  if (!sourceType)
1275  return failure();
1276 
1277  unsigned sourceRank = sourceType.getRank();
1278  IndexType indexType = IndexType::get(context);
1279  auto memrefType =
1280  MemRefType::get({}, sourceType.getElementType(),
1281  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1282  // Base.
1283  inferredReturnTypes.push_back(memrefType);
1284  // Offset.
1285  inferredReturnTypes.push_back(indexType);
1286  // Sizes and strides.
1287  for (unsigned i = 0; i < sourceRank * 2; ++i)
1288  inferredReturnTypes.push_back(indexType);
1289  return success();
1290 }
1291 
1292 void ExtractStridedMetadataOp::getAsmResultNames(
1293  function_ref<void(Value, StringRef)> setNameFn) {
1294  setNameFn(getBaseBuffer(), "base_buffer");
1295  setNameFn(getOffset(), "offset");
1296  // For multi-result to work properly with pretty names and packed syntax `x:3`
1297  // we can only give a pretty name to the first value in the pack.
1298  if (!getSizes().empty()) {
1299  setNameFn(getSizes().front(), "sizes");
1300  setNameFn(getStrides().front(), "strides");
1301  }
1302 }
1303 
1304 /// Helper function to perform the replacement of all constant uses of `values`
1305 /// by a materialized constant extracted from `maybeConstants`.
1306 /// `values` and `maybeConstants` are expected to have the same size.
1307 template <typename Container>
1308 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1309  Container values,
1310  ArrayRef<OpFoldResult> maybeConstants) {
1311  assert(values.size() == maybeConstants.size() &&
1312  " expected values and maybeConstants of the same size");
1313  bool atLeastOneReplacement = false;
1314  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1315  // Don't materialize a constant if there are no uses: this would induce
1316  // infinite loops in the driver.
1317  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1318  continue;
1319  assert(isa<Attribute>(maybeConstant) &&
1320  "The constified value should be either unchanged (i.e., == result) "
1321  "or a constant");
1322  Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1323  loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1324  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1325  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1326  // yet.
1327  op->replaceUsesOfWith(result, constantVal);
1328  atLeastOneReplacement = true;
1329  }
1330  }
1331  return atLeastOneReplacement;
1332 }
1333 
1334 LogicalResult
1335 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1336  SmallVectorImpl<OpFoldResult> &results) {
1337  OpBuilder builder(*this);
1338 
1339  bool atLeastOneReplacement = replaceConstantUsesOf(
1340  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1341  getConstifiedMixedOffset());
1342  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1343  getConstifiedMixedSizes());
1344  atLeastOneReplacement |= replaceConstantUsesOf(
1345  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1346 
1347  return success(atLeastOneReplacement);
1348 }
1349 
1350 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1351  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1352  constifyIndexValues(values, getSource().getType().getShape());
1353  return values;
1354 }
1355 
1356 SmallVector<OpFoldResult>
1357 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1358  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1359  SmallVector<int64_t> staticValues;
1360  int64_t unused;
1361  LogicalResult status =
1362  getSource().getType().getStridesAndOffset(staticValues, unused);
1363  (void)status;
1364  assert(succeeded(status) && "could not get strides from type");
1365  constifyIndexValues(values, staticValues);
1366  return values;
1367 }
1368 
1369 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1370  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1371  SmallVector<OpFoldResult> values(1, offsetOfr);
1372  SmallVector<int64_t> staticValues, unused;
1373  int64_t offset;
1374  LogicalResult status =
1375  getSource().getType().getStridesAndOffset(unused, offset);
1376  (void)status;
1377  assert(succeeded(status) && "could not get offset from type");
1378  staticValues.push_back(offset);
1379  constifyIndexValues(values, staticValues);
1380  return values[0];
1381 }
1382 
1383 //===----------------------------------------------------------------------===//
1384 // GenericAtomicRMWOp
1385 //===----------------------------------------------------------------------===//
1386 
1387 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1388  Value memref, ValueRange ivs) {
1389  OpBuilder::InsertionGuard g(builder);
1390  result.addOperands(memref);
1391  result.addOperands(ivs);
1392 
1393  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1394  Type elementType = memrefType.getElementType();
1395  result.addTypes(elementType);
1396 
1397  Region *bodyRegion = result.addRegion();
1398  builder.createBlock(bodyRegion);
1399  bodyRegion->addArgument(elementType, memref.getLoc());
1400  }
1401 }
1402 
1403 LogicalResult GenericAtomicRMWOp::verify() {
1404  auto &body = getRegion();
1405  if (body.getNumArguments() != 1)
1406  return emitOpError("expected single number of entry block arguments");
1407 
1408  if (getResult().getType() != body.getArgument(0).getType())
1409  return emitOpError("expected block argument of the same type result type");
1410 
1411  bool hasSideEffects =
1412  body.walk([&](Operation *nestedOp) {
1413  if (isMemoryEffectFree(nestedOp))
1414  return WalkResult::advance();
1415  nestedOp->emitError(
1416  "body of 'memref.generic_atomic_rmw' should contain "
1417  "only operations with no side effects");
1418  return WalkResult::interrupt();
1419  })
1420  .wasInterrupted();
1421  return hasSideEffects ? failure() : success();
1422 }
1423 
1424 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1425  OperationState &result) {
1426  OpAsmParser::UnresolvedOperand memref;
1427  Type memrefType;
1428  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1429 
1430  Type indexType = parser.getBuilder().getIndexType();
1431  if (parser.parseOperand(memref) ||
1432  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1433  parser.parseColonType(memrefType) ||
1434  parser.resolveOperand(memref, memrefType, result.operands) ||
1435  parser.resolveOperands(ivs, indexType, result.operands))
1436  return failure();
1437 
1438  Region *body = result.addRegion();
1439  if (parser.parseRegion(*body, {}) ||
1440  parser.parseOptionalAttrDict(result.attributes))
1441  return failure();
1442  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1443  return success();
1444 }
1445 
1446 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1447  p << ' ' << getMemref() << "[" << getIndices()
1448  << "] : " << getMemref().getType() << ' ';
1449  p.printRegion(getRegion());
1450  p.printOptionalAttrDict((*this)->getAttrs());
1451 }
1452 
1453 //===----------------------------------------------------------------------===//
1454 // AtomicYieldOp
1455 //===----------------------------------------------------------------------===//
1456 
1457 LogicalResult AtomicYieldOp::verify() {
1458  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1459  Type resultType = getResult().getType();
1460  if (parentType != resultType)
1461  return emitOpError() << "types mismatch between yield op: " << resultType
1462  << " and its parent: " << parentType;
1463  return success();
1464 }
1465 
1466 //===----------------------------------------------------------------------===//
1467 // GlobalOp
1468 //===----------------------------------------------------------------------===//
1469 
1470 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1471  TypeAttr type,
1472  Attribute initialValue) {
1473  p << type;
1474  if (!op.isExternal()) {
1475  p << " = ";
1476  if (op.isUninitialized())
1477  p << "uninitialized";
1478  else
1479  p.printAttributeWithoutType(initialValue);
1480  }
1481 }
1482 
1483 static ParseResult
1484 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1485  Attribute &initialValue) {
1486  Type type;
1487  if (parser.parseType(type))
1488  return failure();
1489 
1490  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1491  if (!memrefType || !memrefType.hasStaticShape())
1492  return parser.emitError(parser.getNameLoc())
1493  << "type should be static shaped memref, but got " << type;
1494  typeAttr = TypeAttr::get(type);
1495 
1496  if (parser.parseOptionalEqual())
1497  return success();
1498 
1499  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1500  initialValue = UnitAttr::get(parser.getContext());
1501  return success();
1502  }
1503 
1504  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1505  if (parser.parseAttribute(initialValue, tensorType))
1506  return failure();
1507  if (!llvm::isa<ElementsAttr>(initialValue))
1508  return parser.emitError(parser.getNameLoc())
1509  << "initial value should be a unit or elements attribute";
1510  return success();
1511 }
1512 
1513 LogicalResult GlobalOp::verify() {
1514  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1515  if (!memrefType || !memrefType.hasStaticShape())
1516  return emitOpError("type should be static shaped memref, but got ")
1517  << getType();
1518 
1519  // Verify that the initial value, if present, is either a unit attribute or
1520  // an elements attribute.
1521  if (getInitialValue().has_value()) {
1522  Attribute initValue = getInitialValue().value();
1523  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1524  return emitOpError("initial value should be a unit or elements "
1525  "attribute, but got ")
1526  << initValue;
1527 
1528  // Check that the type of the initial value is compatible with the type of
1529  // the global variable.
1530  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1531  // Check the element types match.
1532  auto initElementType =
1533  cast<TensorType>(elementsAttr.getType()).getElementType();
1534  auto memrefElementType = memrefType.getElementType();
1535 
1536  if (initElementType != memrefElementType)
1537  return emitOpError("initial value element expected to be of type ")
1538  << memrefElementType << ", but was of type " << initElementType;
1539 
1540  // Check that the shapes match. Given that memref globals can only produce
1541  // statically shaped memrefs and the elements literal type must have a
1542  // static shape, we can assume both types are shaped.
1543  auto initShape = elementsAttr.getShapedType().getShape();
1544  auto memrefShape = memrefType.getShape();
1545  if (initShape != memrefShape)
1546  return emitOpError("initial value shape expected to be ")
1547  << memrefShape << " but was " << initShape;
1548  }
1549  }
1550 
1551  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1552  uint64_t alignment = *alignAttr;
1553 
1554  if (!llvm::isPowerOf2_64(alignment))
1555  return emitError() << "alignment attribute value " << alignment
1556  << " is not a power of 2";
1557  }
1558 
1559  // TODO: verify visibility for declarations.
1560  return success();
1561 }
1562 
1563 ElementsAttr GlobalOp::getConstantInitValue() {
1564  auto initVal = getInitialValue();
1565  if (getConstant() && initVal.has_value())
1566  return llvm::cast<ElementsAttr>(initVal.value());
1567  return {};
1568 }
1569 
1570 //===----------------------------------------------------------------------===//
1571 // GetGlobalOp
1572 //===----------------------------------------------------------------------===//
1573 
1574 LogicalResult
1575 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1576  // Verify that the result type is same as the type of the referenced
1577  // memref.global op.
1578  auto global =
1579  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1580  if (!global)
1581  return emitOpError("'")
1582  << getName() << "' does not reference a valid global memref";
1583 
1584  Type resultType = getResult().getType();
1585  if (global.getType() != resultType)
1586  return emitOpError("result type ")
1587  << resultType << " does not match type " << global.getType()
1588  << " of the global memref @" << getName();
1589  return success();
1590 }
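// Illustrative pairing checked by the verifier above (assumed IR): the result
// type of memref.get_global must exactly match the type of the referenced
// memref.global.
//
//   memref.global "private" @g : memref<3xf32> = dense<0.0>
//   ...
//   %m = memref.get_global @g : memref<3xf32>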
1591 
1592 //===----------------------------------------------------------------------===//
1593 // LoadOp
1594 //===----------------------------------------------------------------------===//
1595 
1596 LogicalResult LoadOp::verify() {
1597  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1598  return emitOpError("incorrect number of indices for load, expected ")
1599  << getMemRefType().getRank() << " but got " << getIndices().size();
1600  }
1601  return success();
1602 }
1603 
1604 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1605  /// load(memrefcast) -> load
1606  if (succeeded(foldMemRefCast(*this)))
1607  return getResult();
1608  return OpFoldResult();
1609 }
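// The fold above rewrites, for example (illustrative IR):
//
//   %c = memref.cast %m : memref<4xf32> to memref<?xf32>
//   %v = memref.load %c[%i] : memref<?xf32>
// into:
//   %v = memref.load %m[%i] : memref<4xf32>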
1610 
1611 //===----------------------------------------------------------------------===//
1612 // MemorySpaceCastOp
1613 //===----------------------------------------------------------------------===//
1614 
1615 void MemorySpaceCastOp::getAsmResultNames(
1616  function_ref<void(Value, StringRef)> setNameFn) {
1617  setNameFn(getResult(), "memspacecast");
1618 }
1619 
1620 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1621  if (inputs.size() != 1 || outputs.size() != 1)
1622  return false;
1623  Type a = inputs.front(), b = outputs.front();
1624  auto aT = llvm::dyn_cast<MemRefType>(a);
1625  auto bT = llvm::dyn_cast<MemRefType>(b);
1626 
1627  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1628  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1629 
1630  if (aT && bT) {
1631  if (aT.getElementType() != bT.getElementType())
1632  return false;
1633  if (aT.getLayout() != bT.getLayout())
1634  return false;
1635  if (aT.getShape() != bT.getShape())
1636  return false;
1637  return true;
1638  }
1639  if (uaT && ubT) {
1640  return uaT.getElementType() == ubT.getElementType();
1641  }
1642  return false;
1643 }
1644 
1645 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1646  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1647  // t2)
1648  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1649  getSourceMutable().assign(parentCast.getSource());
1650  return getResult();
1651  }
1652  return Value{};
1653 }
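// The fold above collapses chained casts, for example (illustrative IR; the
// memory spaces 1 and 2 are arbitrary):
//
//   %a = memref.memory_space_cast %m : memref<4xf32> to memref<4xf32, 1>
//   %b = memref.memory_space_cast %a : memref<4xf32, 1> to memref<4xf32, 2>
// becomes:
//   %b = memref.memory_space_cast %m : memref<4xf32> to memref<4xf32, 2>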
1654 
1655 //===----------------------------------------------------------------------===//
1656 // PrefetchOp
1657 //===----------------------------------------------------------------------===//
1658 
1659 void PrefetchOp::print(OpAsmPrinter &p) {
1660  p << " " << getMemref() << '[';
1661  p.printOperands(getIndices());
1662  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1663  p << ", locality<" << getLocalityHint();
1664  p << ">, " << (getIsDataCache() ? "data" : "instr");
1665  p.printOptionalAttrDict(
1666  (*this)->getAttrs(),
1667  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1668  p << " : " << getMemRefType();
1669 }
1670 
1671 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1672  OpAsmParser::UnresolvedOperand memrefInfo;
1673  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1674  IntegerAttr localityHint;
1675  MemRefType type;
1676  StringRef readOrWrite, cacheType;
1677 
1678  auto indexTy = parser.getBuilder().getIndexType();
1679  auto i32Type = parser.getBuilder().getIntegerType(32);
1680  if (parser.parseOperand(memrefInfo) ||
1681  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1682  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1683  parser.parseComma() || parser.parseKeyword("locality") ||
1684  parser.parseLess() ||
1685  parser.parseAttribute(localityHint, i32Type, "localityHint",
1686  result.attributes) ||
1687  parser.parseGreater() || parser.parseComma() ||
1688  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1689  parser.resolveOperand(memrefInfo, type, result.operands) ||
1690  parser.resolveOperands(indexInfo, indexTy, result.operands))
1691  return failure();
1692 
1693  if (readOrWrite != "read" && readOrWrite != "write")
1694  return parser.emitError(parser.getNameLoc(),
1695  "rw specifier has to be 'read' or 'write'");
1696  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1697  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1698 
1699  if (cacheType != "data" && cacheType != "instr")
1700  return parser.emitError(parser.getNameLoc(),
1701  "cache type has to be 'data' or 'instr'");
1702 
1703  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1704  parser.getBuilder().getBoolAttr(cacheType == "data"));
1705 
1706  return success();
1707 }
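// Illustrative form matching the parser/printer above (assumed IR):
//
//   memref.prefetch %buf[%i, %j], read, locality<3>, data : memref<64x64xf32>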
1708 
1709 LogicalResult PrefetchOp::verify() {
1710  if (getNumOperands() != 1 + getMemRefType().getRank())
1711  return emitOpError("too few indices");
1712 
1713  return success();
1714 }
1715 
1716 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1717  SmallVectorImpl<OpFoldResult> &results) {
1718  // prefetch(memrefcast) -> prefetch
1719  return foldMemRefCast(*this);
1720 }
1721 
1722 //===----------------------------------------------------------------------===//
1723 // RankOp
1724 //===----------------------------------------------------------------------===//
1725 
1726 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1727  // Constant fold rank when the rank of the operand is known.
1728  auto type = getOperand().getType();
1729  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1730  if (shapedType && shapedType.hasRank())
1731  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1732  return IntegerAttr();
1733 }
1734 
1735 //===----------------------------------------------------------------------===//
1736 // ReinterpretCastOp
1737 //===----------------------------------------------------------------------===//
1738 
1739 void ReinterpretCastOp::getAsmResultNames(
1740  function_ref<void(Value, StringRef)> setNameFn) {
1741  setNameFn(getResult(), "reinterpret_cast");
1742 }
1743 
1744 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1745 /// `staticSizes` and `staticStrides` are automatically filled with
1746 /// source-memref-rank sentinel values that encode dynamic entries.
1747 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1748  MemRefType resultType, Value source,
1749  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1750  ArrayRef<OpFoldResult> strides,
1751  ArrayRef<NamedAttribute> attrs) {
1752  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1753  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1754  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1755  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1756  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1757  result.addAttributes(attrs);
1758  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1759  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1760  b.getDenseI64ArrayAttr(staticSizes),
1761  b.getDenseI64ArrayAttr(staticStrides));
1762 }
1763 
1764 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1765  Value source, OpFoldResult offset,
1766  ArrayRef<OpFoldResult> sizes,
1767  ArrayRef<OpFoldResult> strides,
1768  ArrayRef<NamedAttribute> attrs) {
1769  auto sourceType = cast<BaseMemRefType>(source.getType());
1770  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1771  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1772  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1773  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1774  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1775  auto stridedLayout = StridedLayoutAttr::get(
1776  b.getContext(), staticOffsets.front(), staticStrides);
1777  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1778  stridedLayout, sourceType.getMemorySpace());
1779  build(b, result, resultType, source, offset, sizes, strides, attrs);
1780 }
1781 
1782 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1783  MemRefType resultType, Value source,
1784  int64_t offset, ArrayRef<int64_t> sizes,
1785  ArrayRef<int64_t> strides,
1786  ArrayRef<NamedAttribute> attrs) {
1787  SmallVector<OpFoldResult> sizeValues =
1788  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1789  return b.getI64IntegerAttr(v);
1790  }));
1791  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1792  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1793  return b.getI64IntegerAttr(v);
1794  }));
1795  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1796  strideValues, attrs);
1797 }
1798 
1799 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1800  MemRefType resultType, Value source, Value offset,
1801  ValueRange sizes, ValueRange strides,
1802  ArrayRef<NamedAttribute> attrs) {
1803  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1804  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1805  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1806  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1807  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1808 }
1809 
1810 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1811 // completed automatically, like we have for subview and extract_slice.
1812 LogicalResult ReinterpretCastOp::verify() {
1813  // The source and result memrefs should be in the same memory space.
1814  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1815  auto resultType = llvm::cast<MemRefType>(getType());
1816  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1817  return emitError("different memory spaces specified for source type ")
1818  << srcType << " and result memref type " << resultType;
1819  if (srcType.getElementType() != resultType.getElementType())
1820  return emitError("different element types specified for source type ")
1821  << srcType << " and result memref type " << resultType;
1822 
1823  // Match sizes in result memref type and in static_sizes attribute.
1824  for (auto [idx, resultSize, expectedSize] :
1825  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1826  if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
1827  return emitError("expected result type with size = ")
1828  << (ShapedType::isDynamic(expectedSize)
1829  ? std::string("dynamic")
1830  : std::to_string(expectedSize))
1831  << " instead of " << resultSize << " in dim = " << idx;
1832  }
1833 
1834  // Match offset and strides in static_offset and static_strides attributes. If
1835  // result memref type has no affine map specified, this will assume an
1836  // identity layout.
1837  int64_t resultOffset;
1838  SmallVector<int64_t, 4> resultStrides;
1839  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1840  return emitError("expected result type to have strided layout but found ")
1841  << resultType;
1842 
1843  // Match offset in result memref type and in static_offsets attribute.
1844  int64_t expectedOffset = getStaticOffsets().front();
1845  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
1846  return emitError("expected result type with offset = ")
1847  << (ShapedType::isDynamic(expectedOffset)
1848  ? std::string("dynamic")
1849  : std::to_string(expectedOffset))
1850  << " instead of " << resultOffset;
1851 
1852  // Match strides in result memref type and in static_strides attribute.
1853  for (auto [idx, resultStride, expectedStride] :
1854  llvm::enumerate(resultStrides, getStaticStrides())) {
1855  if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
1856  return emitError("expected result type with stride = ")
1857  << (ShapedType::isDynamic(expectedStride)
1858  ? std::string("dynamic")
1859  : std::to_string(expectedStride))
1860  << " instead of " << resultStride << " in dim = " << idx;
1861  }
1862 
1863  return success();
1864 }
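// Illustrative op accepted by the verifier above (assumed IR): the static
// offset, sizes and strides must agree with the strided layout of the result
// type.
//
//   %dst = memref.reinterpret_cast %src to
//            offset: [0], sizes: [10, 10], strides: [10, 1]
//          : memref<?x?xf32> to memref<10x10xf32, strided<[10, 1]>>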
1865 
1866 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1867  Value src = getSource();
1868  auto getPrevSrc = [&]() -> Value {
1869  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1870  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1871  return prev.getSource();
1872 
1873  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1874  if (auto prev = src.getDefiningOp<CastOp>())
1875  return prev.getSource();
1876 
1877  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1878  // are 0.
1879  if (auto prev = src.getDefiningOp<SubViewOp>())
1880  if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
1881  return prev.getSource();
1882 
1883  return nullptr;
1884  };
1885 
1886  if (auto prevSrc = getPrevSrc()) {
1887  getSourceMutable().assign(prevSrc);
1888  return getResult();
1889  }
1890 
1891  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1892  if (ShapedType::isStaticShape(getType().getShape()) &&
1893  src.getType() == getType() && getStaticOffsets().front() == 0) {
1894  return src;
1895  }
1896 
1897  return nullptr;
1898 }
1899 
1900 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1901  SmallVector<OpFoldResult> values = getMixedSizes();
1902  constifyIndexValues(values, getType().getShape());
1903  return values;
1904 }
1905 
1906 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1907  SmallVector<OpFoldResult> values = getMixedStrides();
1908  SmallVector<int64_t> staticValues;
1909  int64_t unused;
1910  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
1911  (void)status;
1912  assert(succeeded(status) && "could not get strides from type");
1913  constifyIndexValues(values, staticValues);
1914  return values;
1915 }
1916 
1917 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1918  SmallVector<OpFoldResult> values = getMixedOffsets();
1919  assert(values.size() == 1 &&
1920  "reinterpret_cast must have one and only one offset");
1921  SmallVector<int64_t> staticValues, unused;
1922  int64_t offset;
1923  LogicalResult status = getType().getStridesAndOffset(unused, offset);
1924  (void)status;
1925  assert(succeeded(status) && "could not get offset from type");
1926  staticValues.push_back(offset);
1927  constifyIndexValues(values, staticValues);
1928  return values[0];
1929 }
1930 
1931 namespace {
1932 /// Replace the sequence:
1933 /// ```
1934 /// base, offset, sizes, strides = extract_strided_metadata src
1935 /// dst = reinterpret_cast base to offset, sizes, strides
1936 /// ```
1937 /// With
1938 ///
1939 /// ```
1940 /// dst = memref.cast src
1941 /// ```
1942 ///
1943 /// Note: The cast operation is only inserted when the type of dst and src
1944 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
1945 ///
1946 /// This pattern also matches when the offset, sizes, and strides don't come
1947 /// directly from the `extract_strided_metadata`'s results but it can be
1948 /// statically proven that they would hold the same values.
1949 ///
1950 /// For instance, the following sequence would be replaced:
1951 /// ```
1952 /// base, offset, sizes, strides =
1953 /// extract_strided_metadata memref : memref<3x4xty>
1954 /// dst = reinterpret_cast base to 0, [3, 4], strides
1955 /// ```
1956 /// Because we know (thanks to the type of the input memref) that variable
1957 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
1958 ///
1959 /// Similarly, the following sequence would be replaced:
1960 /// ```
1961 /// c0 = arith.constant 0
1962 /// c4 = arith.constant 4
1963 /// base, offset, sizes, strides =
1964 /// extract_strided_metadata memref : memref<3x4xty>
1965 /// dst = reinterpret_cast base to c0, [3, c4], strides
1966 /// ```
1967 /// Because we know that `offset` and `c0` will hold 0
1968 /// and `c4` will hold 4.
1969 ///
1970 /// If the pattern above does not match, the input of the
1971 /// extract_strided_metadata is always folded into the input of the
1972 /// reinterpret_cast operator. This allows for dead code elimination to get rid
1973 /// of the extract_strided_metadata in some cases.
1974 struct ReinterpretCastOpExtractStridedMetadataFolder
1975  : public OpRewritePattern<ReinterpretCastOp> {
1976 public:
1977  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
1978 
1979  LogicalResult matchAndRewrite(ReinterpretCastOp op,
1980  PatternRewriter &rewriter) const override {
1981  auto extractStridedMetadata =
1982  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
1983  if (!extractStridedMetadata)
1984  return failure();
1985 
1986  // Check if the reinterpret cast reconstructs a memref with the exact same
1987  // properties as the extract strided metadata.
1988  auto isReinterpretCastNoop = [&]() -> bool {
1989  // First, check that the strides are the same.
1990  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
1991  op.getConstifiedMixedStrides()))
1992  return false;
1993 
1994  // Second, check the sizes.
1995  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
1996  op.getConstifiedMixedSizes()))
1997  return false;
1998 
1999  // Finally, check the offset.
2000  assert(op.getMixedOffsets().size() == 1 &&
2001  "reinterpret_cast with more than one offset should have been "
2002  "rejected by the verifier");
2003  return extractStridedMetadata.getConstifiedMixedOffset() ==
2004  op.getConstifiedMixedOffset();
2005  };
2006 
2007  if (!isReinterpretCastNoop()) {
2008  // If the extract_strided_metadata / reinterpret_cast pair can't be
2009  // completely folded, then we can still fold the input of the
2010  // extract_strided_metadata into the input of the reinterpret_cast.
2011  // In some cases (e.g., static dimensions), the extract_strided_metadata
2012  // is then eliminated by dead code elimination.
2013  //
2014  // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2015  //
2016  // We can always fold the input of a extract_strided_metadata operator
2017  // to the input of a reinterpret_cast operator, because they point to
2018  // the same memory. Note that the reinterpret_cast does not use the
2019  // layout of its input memref, only its base memory pointer which is
2020  // the same as the base pointer returned by the extract_strided_metadata
2021  // operator and the base pointer of the extract_strided_metadata memref
2022  // input.
2023  rewriter.modifyOpInPlace(op, [&]() {
2024  op.getSourceMutable().assign(extractStridedMetadata.getSource());
2025  });
2026  return success();
2027  }
2028 
2029  // At this point, we know that the back and forth between extract strided
2030  // metadata and reinterpret cast is a noop. However, the final type of the
2031  // reinterpret cast may not be exactly the same as the original memref.
2032  // E.g., it could be changing a dimension from static to dynamic. Check that
2033  // here and add a cast if necessary.
2034  Type srcTy = extractStridedMetadata.getSource().getType();
2035  if (srcTy == op.getResult().getType())
2036  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2037  else
2038  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2039  extractStridedMetadata.getSource());
2040 
2041  return success();
2042  }
2043 };
2044 } // namespace
2045 
2046 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2047  MLIRContext *context) {
2048  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2049 }
2050 
2051 //===----------------------------------------------------------------------===//
2052 // Reassociative reshape ops
2053 //===----------------------------------------------------------------------===//
2054 
2055 void CollapseShapeOp::getAsmResultNames(
2056  function_ref<void(Value, StringRef)> setNameFn) {
2057  setNameFn(getResult(), "collapse_shape");
2058 }
2059 
2060 void ExpandShapeOp::getAsmResultNames(
2061  function_ref<void(Value, StringRef)> setNameFn) {
2062  setNameFn(getResult(), "expand_shape");
2063 }
2064 
2065 LogicalResult ExpandShapeOp::reifyResultShapes(
2066  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2067  reifiedResultShapes = {
2068  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2069  return success();
2070 }
2071 
2072 /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
2073 /// result and operand. Layout maps are verified separately.
2074 ///
2075 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2076 /// allowed in a reassociation group.
2077 static LogicalResult
2078 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2079  ArrayRef<int64_t> expandedShape,
2080  ArrayRef<ReassociationIndices> reassociation,
2081  bool allowMultipleDynamicDimsPerGroup) {
2082  // There must be one reassociation group per collapsed dimension.
2083  if (collapsedShape.size() != reassociation.size())
2084  return op->emitOpError("invalid number of reassociation groups: found ")
2085  << reassociation.size() << ", expected " << collapsedShape.size();
2086 
2087  // The next expected expanded dimension index (while iterating over
2088  // reassociation indices).
2089  int64_t nextDim = 0;
2090  for (const auto &it : llvm::enumerate(reassociation)) {
2091  ReassociationIndices group = it.value();
2092  int64_t collapsedDim = it.index();
2093 
2094  bool foundDynamic = false;
2095  for (int64_t expandedDim : group) {
2096  if (expandedDim != nextDim++)
2097  return op->emitOpError("reassociation indices must be contiguous");
2098 
2099  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2100  return op->emitOpError("reassociation index ")
2101  << expandedDim << " is out of bounds";
2102 
2103  // Check if there are multiple dynamic dims in a reassociation group.
2104  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2105  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2106  return op->emitOpError(
2107  "at most one dimension in a reassociation group may be dynamic");
2108  foundDynamic = true;
2109  }
2110  }
2111 
2112  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2113  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2114  return op->emitOpError("collapsed dim (")
2115  << collapsedDim
2116  << ") must be dynamic if and only if reassociation group is "
2117  "dynamic";
2118 
2119  // If all dims in the reassociation group are static, the size of the
2120  // collapsed dim can be verified.
2121  if (!foundDynamic) {
2122  int64_t groupSize = 1;
2123  for (int64_t expandedDim : group)
2124  groupSize *= expandedShape[expandedDim];
2125  if (groupSize != collapsedShape[collapsedDim])
2126  return op->emitOpError("collapsed dim size (")
2127  << collapsedShape[collapsedDim]
2128  << ") must equal reassociation group size (" << groupSize << ")";
2129  }
2130  }
2131 
2132  if (collapsedShape.empty()) {
2133  // Rank 0: All expanded dimensions must be 1.
2134  for (int64_t d : expandedShape)
2135  if (d != 1)
2136  return op->emitOpError(
2137  "rank 0 memrefs can only be extended/collapsed with/from ones");
2138  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2139  // Rank >= 1: Number of dimensions among all reassociation groups must match
2140  // the result memref rank.
2141  return op->emitOpError("expanded rank (")
2142  << expandedShape.size()
2143  << ") inconsistent with number of reassociation indices (" << nextDim
2144  << ")";
2145  }
2146 
2147  return success();
2148 }
2149 
2150 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2151  return getSymbolLessAffineMaps(getReassociationExprs());
2152 }
2153 
2154 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2155  return convertReassociationIndicesToExprs(getContext(),
2156  getReassociationIndices());
2157 }
2158 
2159 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2160  return getSymbolLessAffineMaps(getReassociationExprs());
2161 }
2162 
2163 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2164  return convertReassociationIndicesToExprs(getContext(),
2165  getReassociationIndices());
2166 }
2167 
2168 /// Compute the layout map after expanding a given source MemRef type with the
2169 /// specified reassociation indices.
2170 static FailureOr<StridedLayoutAttr>
2171 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2172  ArrayRef<ReassociationIndices> reassociation) {
2173  int64_t srcOffset;
2174  SmallVector<int64_t> srcStrides;
2175  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2176  return failure();
2177  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2178 
2179  // 1-1 mapping between srcStrides and reassociation packs.
2180  // Each srcStride starts with the given value and gets expanded according to
2181  // the proper entries in resultShape.
2182  // Example:
2183  // srcStrides = [10000, 1 , 100 ],
2184  // reassociations = [ [0], [1], [2, 3, 4]],
2185  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2186  // -> For the purpose of stride calculation, the useful sizes are:
2187  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2188  // resultStrides = [10000, 1, 600, 200, 100]
2189  // Note that a stride does not get expanded along the first entry of each
2190  // shape pack.
2191  SmallVector<int64_t> reverseResultStrides;
2192  reverseResultStrides.reserve(resultShape.size());
2193  unsigned shapeIndex = resultShape.size() - 1;
2194  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2195  ReassociationIndices reassoc = std::get<0>(it);
2196  int64_t currentStrideToExpand = std::get<1>(it);
2197  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2198  reverseResultStrides.push_back(currentStrideToExpand);
2199  currentStrideToExpand =
2200  (SaturatedInteger::wrap(currentStrideToExpand) *
2201  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2202  .asInteger();
2203  }
2204  }
2205  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2206  resultStrides.resize(resultShape.size(), 1);
2207  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2208 }
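// Using the numbers from the comment above, the computed layout corresponds to
// an expansion such as (illustrative IR; types and assembly format assumed):
//
//   %e = memref.expand_shape %m [[0], [1], [2, 3, 4]] output_shape [2, 5, 4, 3, 2]
//       : memref<2x5x24xf32, strided<[10000, 1, 100]>>
//      into memref<2x5x4x3x2xf32, strided<[10000, 1, 600, 200, 100]>>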
2209 
2210 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2211  MemRefType srcType, ArrayRef<int64_t> resultShape,
2212  ArrayRef<ReassociationIndices> reassociation) {
2213  if (srcType.getLayout().isIdentity()) {
2214  // If the source is contiguous (i.e., no layout map specified), so is the
2215  // result.
2216  MemRefLayoutAttrInterface layout;
2217  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2218  srcType.getMemorySpace());
2219  }
2220 
2221  // Source may not be contiguous. Compute the layout map.
2222  FailureOr<StridedLayoutAttr> computedLayout =
2223  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2224  if (failed(computedLayout))
2225  return failure();
2226  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2227  srcType.getMemorySpace());
2228 }
2229 
2230 FailureOr<SmallVector<OpFoldResult>>
2231 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2232  MemRefType expandedType,
2233  ArrayRef<ReassociationIndices> reassociation,
2234  ArrayRef<OpFoldResult> inputShape) {
2235  std::optional<SmallVector<OpFoldResult>> outputShape =
2236  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2237  inputShape);
2238  if (!outputShape)
2239  return failure();
2240  return *outputShape;
2241 }
2242 
2243 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2244  Type resultType, Value src,
2245  ArrayRef<ReassociationIndices> reassociation,
2246  ArrayRef<OpFoldResult> outputShape) {
2247  auto [staticOutputShape, dynamicOutputShape] =
2248  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2249  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2250  getReassociationIndicesAttribute(builder, reassociation),
2251  dynamicOutputShape, staticOutputShape);
2252 }
2253 
2254 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2255  Type resultType, Value src,
2256  ArrayRef<ReassociationIndices> reassociation) {
2257  SmallVector<OpFoldResult> inputShape =
2258  getMixedSizes(builder, result.location, src);
2259  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2260  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2261  builder, result.location, memrefResultTy, reassociation, inputShape);
2262  // Failure of this assertion usually indicates presence of multiple
2263  // dynamic dimensions in the same reassociation group.
2264  assert(succeeded(outputShape) && "unable to infer output shape");
2265  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2266 }
2267 
2268 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2269  ArrayRef<int64_t> resultShape, Value src,
2270  ArrayRef<ReassociationIndices> reassociation) {
2271  // Only ranked memref source values are supported.
2272  auto srcType = llvm::cast<MemRefType>(src.getType());
2273  FailureOr<MemRefType> resultType =
2274  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2275  // Failure of this assertion usually indicates a problem with the source
2276  // type, e.g., could not get strides/offset.
2277  assert(succeeded(resultType) && "could not compute layout");
2278  build(builder, result, *resultType, src, reassociation);
2279 }
2280 
2281 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2282  ArrayRef<int64_t> resultShape, Value src,
2283  ArrayRef<ReassociationIndices> reassociation,
2284  ArrayRef<OpFoldResult> outputShape) {
2285  // Only ranked memref source values are supported.
2286  auto srcType = llvm::cast<MemRefType>(src.getType());
2287  FailureOr<MemRefType> resultType =
2288  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2289  // Failure of this assertion usually indicates a problem with the source
2290  // type, e.g., could not get strides/offset.
2291  assert(succeeded(resultType) && "could not compute layout");
2292  build(builder, result, *resultType, src, reassociation, outputShape);
2293 }
2294 
2295 LogicalResult ExpandShapeOp::verify() {
2296  MemRefType srcType = getSrcType();
2297  MemRefType resultType = getResultType();
2298 
2299  if (srcType.getRank() > resultType.getRank()) {
2300  auto r0 = srcType.getRank();
2301  auto r1 = resultType.getRank();
2302  return emitOpError("has source rank ")
2303  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2304  << r0 << " > " << r1 << ").";
2305  }
2306 
2307  // Verify result shape.
2308  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2309  resultType.getShape(),
2310  getReassociationIndices(),
2311  /*allowMultipleDynamicDimsPerGroup=*/true)))
2312  return failure();
2313 
2314  // Compute expected result type (including layout map).
2315  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2316  srcType, resultType.getShape(), getReassociationIndices());
2317  if (failed(expectedResultType))
2318  return emitOpError("invalid source layout map");
2319 
2320  // Check actual result type.
2321  if (*expectedResultType != resultType)
2322  return emitOpError("expected expanded type to be ")
2323  << *expectedResultType << " but found " << resultType;
2324 
2325  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2326  return emitOpError("expected number of static shape bounds to be equal to "
2327  "the output rank (")
2328  << resultType.getRank() << ") but found "
2329  << getStaticOutputShape().size() << " inputs instead";
2330 
2331  if ((int64_t)getOutputShape().size() !=
2332  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2333  return emitOpError("mismatch in dynamic dims in output_shape and "
2334  "static_output_shape: static_output_shape has ")
2335  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2336  << " dynamic dims while output_shape has " << getOutputShape().size()
2337  << " values";
2338 
2339  // Verify if provided output shapes are in agreement with output type.
2340  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2341  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2342  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2343  if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
2344  return emitOpError("invalid output shape provided at pos ") << pos;
2345  }
2346  }
2347 
2348  return success();
2349 }
2350 
2351 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2352  MLIRContext *context) {
2353  results.add<
2354  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2355  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2356 }
2357 
2358 /// Compute the layout map after collapsing a given source MemRef type with the
2359 /// specified reassociation indices.
2360 ///
2361 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2362 /// not possible to check this by inspecting a MemRefType in the general case.
2363 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2364 /// be valid (and thus accepted by this function) unless `strict = true`.
2365 static FailureOr<StridedLayoutAttr>
2366 computeCollapsedLayoutMap(MemRefType srcType,
2367  ArrayRef<ReassociationIndices> reassociation,
2368  bool strict = false) {
2369  int64_t srcOffset;
2370  SmallVector<int64_t> srcStrides;
2371  auto srcShape = srcType.getShape();
2372  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2373  return failure();
2374 
2375  // The result stride of a reassociation group is the stride of the last entry
2376  // of the reassociation. (TODO: Should be the minimum stride in the
2377  // reassociation because strides are not necessarily sorted. E.g., when using
2378  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2379  // strides are meaningless and could have any arbitrary value.
2380  SmallVector<int64_t> resultStrides;
2381  resultStrides.reserve(reassociation.size());
2382  for (const ReassociationIndices &reassoc : reassociation) {
2383  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2384  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2385  ref = ref.drop_back();
2386  if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2387  resultStrides.push_back(srcStrides[ref.back()]);
2388  } else {
2389  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2390  // the corresponding stride may have to be skipped. (See above comment.)
2391  // Therefore, the result stride cannot be statically determined and must
2392  // be dynamic.
2393  resultStrides.push_back(ShapedType::kDynamic);
2394  }
2395  }
2396 
2397  // Validate that each reassociation group is contiguous.
2398  unsigned resultStrideIndex = resultStrides.size() - 1;
2399  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2400  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2401  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2402  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2403  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2404 
2405  // Both source and result stride must have the same static value. In that
2406  // case, we can be sure that the dimensions are collapsible (because they
2407  // are contiguous).
2408  // If `strict = false` (default during op verification), we accept cases
2409  // where one or both strides are dynamic. This is best effort: We reject
2410  // ops where obviously non-contiguous dims are collapsed, but accept ops
2411  // where we cannot be sure statically. Such ops may fail at runtime. See
2412  // the op documentation for details.
2413  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2414  if (strict && (stride.saturated || srcStride.saturated))
2415  return failure();
2416 
2417  // Dimensions of size 1 should be skipped, because their strides are
2418  // meaningless and could have any arbitrary value.
2419  if (srcShape[idx - 1] == 1)
2420  continue;
2421 
2422  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2423  return failure();
2424  }
2425  }
2426  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2427 }
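// A small worked example of the contiguity check above (illustrative IR):
// collapsing with reassociation [[0, 1]] starts from the stride of the last
// group entry (1); walking the group backwards, 1 * 5 == 5 matches the next
// source stride, so the group is contiguous:
//
//   %c = memref.collapse_shape %m [[0, 1]]
//       : memref<4x5xf32, strided<[5, 1]>> into memref<20xf32, strided<[1]>>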
2428 
2429 bool CollapseShapeOp::isGuaranteedCollapsible(
2430  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2431  // MemRefs with identity layout are always collapsible.
2432  if (srcType.getLayout().isIdentity())
2433  return true;
2434 
2435  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2436  /*strict=*/true));
2437 }
2438 
2439 MemRefType CollapseShapeOp::computeCollapsedType(
2440  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2441  SmallVector<int64_t> resultShape;
2442  resultShape.reserve(reassociation.size());
2443  for (const ReassociationIndices &group : reassociation) {
2444  auto groupSize = SaturatedInteger::wrap(1);
2445  for (int64_t srcDim : group)
2446  groupSize =
2447  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2448  resultShape.push_back(groupSize.asInteger());
2449  }
2450 
2451  if (srcType.getLayout().isIdentity()) {
2452  // If the source is contiguous (i.e., no layout map specified), so is the
2453  // result.
2454  MemRefLayoutAttrInterface layout;
2455  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2456  srcType.getMemorySpace());
2457  }
2458 
2459  // Source may not be fully contiguous. Compute the layout map.
2460  // Note: Dimensions that are collapsed into a single dim are assumed to be
2461  // contiguous.
2462  FailureOr<StridedLayoutAttr> computedLayout =
2463  computeCollapsedLayoutMap(srcType, reassociation);
2464  assert(succeeded(computedLayout) &&
2465  "invalid source layout map or collapsing non-contiguous dims");
2466  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2467  srcType.getMemorySpace());
2468 }
2469 
2470 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2471  ArrayRef<ReassociationIndices> reassociation,
2472  ArrayRef<NamedAttribute> attrs) {
2473  auto srcType = llvm::cast<MemRefType>(src.getType());
2474  MemRefType resultType =
2475  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2476  result.addAttribute(getReassociationAttrStrName(),
2477  getReassociationIndicesAttribute(b, reassociation));
2478  build(b, result, resultType, src, attrs);
2479 }
2480 
2481 LogicalResult CollapseShapeOp::verify() {
2482  MemRefType srcType = getSrcType();
2483  MemRefType resultType = getResultType();
2484 
2485  if (srcType.getRank() < resultType.getRank()) {
2486  auto r0 = srcType.getRank();
2487  auto r1 = resultType.getRank();
2488  return emitOpError("has source rank ")
2489  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2490  << r0 << " < " << r1 << ").";
2491  }
2492 
2493  // Verify result shape.
2494  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2495  srcType.getShape(), getReassociationIndices(),
2496  /*allowMultipleDynamicDimsPerGroup=*/true)))
2497  return failure();
2498 
2499  // Compute expected result type (including layout map).
2500  MemRefType expectedResultType;
2501  if (srcType.getLayout().isIdentity()) {
2502  // If the source is contiguous (i.e., no layout map specified), so is the
2503  // result.
2504  MemRefLayoutAttrInterface layout;
2505  expectedResultType =
2506  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2507  srcType.getMemorySpace());
2508  } else {
2509  // Source may not be fully contiguous. Compute the layout map.
2510  // Note: Dimensions that are collapsed into a single dim are assumed to be
2511  // contiguous.
2512  FailureOr<StridedLayoutAttr> computedLayout =
2513  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2514  if (failed(computedLayout))
2515  return emitOpError(
2516  "invalid source layout map or collapsing non-contiguous dims");
2517  expectedResultType =
2518  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2519  *computedLayout, srcType.getMemorySpace());
2520  }
2521 
2522  if (expectedResultType != resultType)
2523  return emitOpError("expected collapsed type to be ")
2524  << expectedResultType << " but found " << resultType;
2525 
2526  return success();
2527 }
2528 
2529 struct CollapseShapeOpMemRefCastFolder
2530  : public OpRewritePattern<CollapseShapeOp> {
2531 public:
2532  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2533 
2534  LogicalResult matchAndRewrite(CollapseShapeOp op,
2535  PatternRewriter &rewriter) const override {
2536  auto cast = op.getOperand().getDefiningOp<CastOp>();
2537  if (!cast)
2538  return failure();
2539 
2540  if (!CastOp::canFoldIntoConsumerOp(cast))
2541  return failure();
2542 
2543  Type newResultType = CollapseShapeOp::computeCollapsedType(
2544  llvm::cast<MemRefType>(cast.getOperand().getType()),
2545  op.getReassociationIndices());
2546 
2547  if (newResultType == op.getResultType()) {
2548  rewriter.modifyOpInPlace(
2549  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2550  } else {
2551  Value newOp = rewriter.create<CollapseShapeOp>(
2552  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2553  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2554  }
2555  return success();
2556  }
2557 };
2558 
2559 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2560  MLIRContext *context) {
2561  results.add<
2562  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2563  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2564  memref::DimOp, MemRefType>,
2565  CollapseShapeOpMemRefCastFolder>(context);
2566 }
2567 
2568 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2569  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2570  adaptor.getOperands());
2571 }
2572 
2573 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2574  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2575  adaptor.getOperands());
2576 }
2577 
2578 //===----------------------------------------------------------------------===//
2579 // ReshapeOp
2580 //===----------------------------------------------------------------------===//
2581 
2582 void ReshapeOp::getAsmResultNames(
2583  function_ref<void(Value, StringRef)> setNameFn) {
2584  setNameFn(getResult(), "reshape");
2585 }
2586 
2587 LogicalResult ReshapeOp::verify() {
2588  Type operandType = getSource().getType();
2589  Type resultType = getResult().getType();
2590 
2591  Type operandElementType =
2592  llvm::cast<ShapedType>(operandType).getElementType();
2593  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2594  if (operandElementType != resultElementType)
2595  return emitOpError("element types of source and destination memref "
2596  "types should be the same");
2597 
2598  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2599  if (!operandMemRefType.getLayout().isIdentity())
2600  return emitOpError("source memref type should have identity affine map");
2601 
2602  int64_t shapeSize =
2603  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2604  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2605  if (resultMemRefType) {
2606  if (!resultMemRefType.getLayout().isIdentity())
2607  return emitOpError("result memref type should have identity affine map");
2608  if (shapeSize == ShapedType::kDynamic)
2609  return emitOpError("cannot use shape operand with dynamic length to "
2610  "reshape to statically-ranked memref type");
2611  if (shapeSize != resultMemRefType.getRank())
2612  return emitOpError(
2613  "length of shape operand differs from the result's memref rank");
2614  }
2615  return success();
2616 }
2617 
2618 //===----------------------------------------------------------------------===//
2619 // StoreOp
2620 //===----------------------------------------------------------------------===//
2621 
2622 LogicalResult StoreOp::verify() {
2623  if (getNumOperands() != 2 + getMemRefType().getRank())
2624  return emitOpError("store index operand count not equal to memref rank");
2625 
2626  return success();
2627 }
2628 
2629 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2630  SmallVectorImpl<OpFoldResult> &results) {
2631  /// store(memrefcast) -> store
2632  return foldMemRefCast(*this, getValueToStore());
2633 }
2634 
2635 //===----------------------------------------------------------------------===//
2636 // SubViewOp
2637 //===----------------------------------------------------------------------===//
2638 
2639 void SubViewOp::getAsmResultNames(
2640  function_ref<void(Value, StringRef)> setNameFn) {
2641  setNameFn(getResult(), "subview");
2642 }
2643 
2644 /// A subview result type can be fully inferred from the source type and the
2645 /// static representation of offsets, sizes and strides. Special sentinels
2646 /// encode the dynamic case.
2647 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2648  ArrayRef<int64_t> staticOffsets,
2649  ArrayRef<int64_t> staticSizes,
2650  ArrayRef<int64_t> staticStrides) {
2651  unsigned rank = sourceMemRefType.getRank();
2652  (void)rank;
2653  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2654  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2655  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2656 
2657  // Extract source offset and strides.
2658  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2659 
2660  // Compute target offset whose value is:
2661  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2662  int64_t targetOffset = sourceOffset;
2663  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2664  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2665  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2666  SaturatedInteger::wrap(staticOffset) *
2667  SaturatedInteger::wrap(sourceStride))
2668  .asInteger();
2669  }
2670 
2671  // Compute target stride whose value is:
2672  // `sourceStrides_i * staticStrides_i`.
2673  SmallVector<int64_t, 4> targetStrides;
2674  targetStrides.reserve(staticOffsets.size());
2675  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2676  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2677  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2678  SaturatedInteger::wrap(staticStride))
2679  .asInteger());
2680  }
2681 
2682  // The type is now known.
2683  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2684  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2685  targetOffset, targetStrides),
2686  sourceMemRefType.getMemorySpace());
2687 }
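// A worked example of the arithmetic above (illustrative): for a source
// memref<8x16xf32> (strides [16, 1], offset 0) with static offsets [2, 4],
// sizes [4, 8] and strides [1, 2]:
//   targetOffset  = 0 + 2*16 + 4*1 = 36
//   targetStrides = [16*1, 1*2]    = [16, 2]
// so the inferred type is memref<4x8xf32, strided<[16, 2], offset: 36>>.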
2688 
2689 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2690  ArrayRef<OpFoldResult> offsets,
2691  ArrayRef<OpFoldResult> sizes,
2692  ArrayRef<OpFoldResult> strides) {
2693  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2694  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2695  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2696  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2697  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2698  if (!hasValidSizesOffsets(staticOffsets))
2699  return {};
2700  if (!hasValidSizesOffsets(staticSizes))
2701  return {};
2702  if (!hasValidStrides(staticStrides))
2703  return {};
2704  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2705  staticSizes, staticStrides);
2706 }
2707 
2708 MemRefType SubViewOp::inferRankReducedResultType(
2709  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2710  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2711  ArrayRef<int64_t> strides) {
2712  MemRefType inferredType =
2713  inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2714  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2715  "expected result shape rank to not exceed the inferred rank");
2716  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2717  return inferredType;
2718 
2719  // Compute which dimensions are dropped.
2720  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2721  computeRankReductionMask(inferredType.getShape(), resultShape);
2722  assert(dimsToProject.has_value() && "invalid rank reduction");
2723 
2724  // Compute the layout and result type.
2725  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2726  SmallVector<int64_t> rankReducedStrides;
2727  rankReducedStrides.reserve(resultShape.size());
2728  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2729  if (!dimsToProject->contains(idx))
2730  rankReducedStrides.push_back(value);
2731  }
2732  return MemRefType::get(resultShape, inferredType.getElementType(),
2733  StridedLayoutAttr::get(inferredLayout.getContext(),
2734  inferredLayout.getOffset(),
2735  rankReducedStrides),
2736  inferredType.getMemorySpace());
2737 }
2738 
2739 MemRefType SubViewOp::inferRankReducedResultType(
2740  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2741  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2742  ArrayRef<OpFoldResult> strides) {
2743  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2744  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2745  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2746  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2747  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2748  return SubViewOp::inferRankReducedResultType(
2749  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2750  staticStrides);
2751 }
2752 
2753 // Build a SubViewOp with mixed static and dynamic entries and custom result
2754 // type. If the type passed is nullptr, it is inferred.
2755 void SubViewOp::build(OpBuilder &b, OperationState &result,
2756  MemRefType resultType, Value source,
2757  ArrayRef<OpFoldResult> offsets,
2758  ArrayRef<OpFoldResult> sizes,
2759  ArrayRef<OpFoldResult> strides,
2760  ArrayRef<NamedAttribute> attrs) {
2761  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2762  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2763  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2764  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2765  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2766  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2767  // Structuring implementation this way avoids duplication between builders.
2768  if (!resultType) {
2769  resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2770  staticSizes, staticStrides);
2771  }
2772  result.addAttributes(attrs);
2773  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2774  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2775  b.getDenseI64ArrayAttr(staticSizes),
2776  b.getDenseI64ArrayAttr(staticStrides));
2777 }
2778 
2779 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2780 // type.
2781 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2782  ArrayRef<OpFoldResult> offsets,
2783  ArrayRef<OpFoldResult> sizes,
2784  ArrayRef<OpFoldResult> strides,
2785  ArrayRef<NamedAttribute> attrs) {
2786  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2787 }
2788 
2789 // Build a SubViewOp with static entries and inferred result type.
2790 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2791  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2792  ArrayRef<int64_t> strides,
2793  ArrayRef<NamedAttribute> attrs) {
2794  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2795  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2796  return b.getI64IntegerAttr(v);
2797  }));
2798  SmallVector<OpFoldResult> sizeValues =
2799  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2800  return b.getI64IntegerAttr(v);
2801  }));
2802  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2803  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2804  return b.getI64IntegerAttr(v);
2805  }));
2806  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2807 }
2808 
2809 // Build a SubViewOp with dynamic entries and custom result type. If the
2810 // type passed is nullptr, it is inferred.
2811 void SubViewOp::build(OpBuilder &b, OperationState &result,
2812  MemRefType resultType, Value source,
2813  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2814  ArrayRef<int64_t> strides,
2815  ArrayRef<NamedAttribute> attrs) {
2816  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2817  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2818  return b.getI64IntegerAttr(v);
2819  }));
2820  SmallVector<OpFoldResult> sizeValues =
2821  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2822  return b.getI64IntegerAttr(v);
2823  }));
2824  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2825  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2826  return b.getI64IntegerAttr(v);
2827  }));
2828  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2829  attrs);
2830 }
2831 
2832 // Build a SubViewOp with dynamic entries and custom result type. If the type
2833 // passed is nullptr, it is inferred.
2834 void SubViewOp::build(OpBuilder &b, OperationState &result,
2835  MemRefType resultType, Value source, ValueRange offsets,
2836  ValueRange sizes, ValueRange strides,
2837  ArrayRef<NamedAttribute> attrs) {
2838  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2839  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2840  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2841  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2842  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2843  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2844  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2845 }
2846 
2847 // Build a SubViewOp with dynamic entries and inferred result type.
2848 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2849  ValueRange offsets, ValueRange sizes, ValueRange strides,
2850  ArrayRef<NamedAttribute> attrs) {
2851  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2852 }
2853 
2854 /// For ViewLikeOpInterface.
2855 Value SubViewOp::getViewSource() { return getSource(); }
2856 
2857 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2858 /// static value).
2859 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2860  int64_t t1Offset, t2Offset;
2861  SmallVector<int64_t> t1Strides, t2Strides;
2862  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2863  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2864  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2865 }
2866 
2867 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2868 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2869 /// marked as dropped in `droppedDims`.
2870 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2871  const llvm::SmallBitVector &droppedDims) {
2872  assert(size_t(t1.getRank()) == droppedDims.size() &&
2873  "incorrect number of bits");
2874  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2875  "incorrect number of dropped dims");
2876  int64_t t1Offset, t2Offset;
2877  SmallVector<int64_t> t1Strides, t2Strides;
2878  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2879  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2880  if (failed(res1) || failed(res2))
2881  return false;
2882  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2883  if (droppedDims[i])
2884  continue;
2885  if (t1Strides[i] != t2Strides[j])
2886  return false;
2887  ++j;
2888  }
2889  return true;
2890 }
2891 
2892 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2893  SubViewOp op, Type expectedType) {
2894  auto memrefType = llvm::cast<ShapedType>(expectedType);
2895  switch (result) {
2896  case SliceVerificationResult::Success:
2897  return success();
2898  case SliceVerificationResult::RankTooLarge:
2899  return op->emitError("expected result rank to be smaller or equal to ")
2900  << "the source rank, but got " << op.getType();
2901  case SliceVerificationResult::SizeMismatch:
2902  return op->emitError("expected result type to be ")
2903  << expectedType
2904  << " or a rank-reduced version. (mismatch of result sizes), but got "
2905  << op.getType();
2906  case SliceVerificationResult::ElemTypeMismatch:
2907  return op->emitError("expected result element type to be ")
2908  << memrefType.getElementType() << ", but got " << op.getType();
2909  case SliceVerificationResult::MemSpaceMismatch:
2910  return op->emitError(
2911  "expected result and source memory spaces to match, but got ")
2912  << op.getType();
2913  case SliceVerificationResult::LayoutMismatch:
2914  return op->emitError("expected result type to be ")
2915  << expectedType
2916  << " or a rank-reduced version. (mismatch of result layout), but "
2917  "got "
2918  << op.getType();
2919  }
2920  llvm_unreachable("unexpected subview verification result");
2921 }
2922 
2923 /// Verifier for SubViewOp.
2924 LogicalResult SubViewOp::verify() {
2925  MemRefType baseType = getSourceType();
2926  MemRefType subViewType = getType();
2927  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
2928  ArrayRef<int64_t> staticSizes = getStaticSizes();
2929  ArrayRef<int64_t> staticStrides = getStaticStrides();
2930 
2931  // The base memref and the view memref should be in the same memory space.
2932  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2933  return emitError("different memory spaces specified for base memref "
2934  "type ")
2935  << baseType << " and subview memref type " << subViewType;
2936 
2937  // Verify that the base memref type has a strided layout map.
2938  if (!baseType.isStrided())
2939  return emitError("base type ") << baseType << " is not strided";
2940 
2941  // Compute the expected result type, assuming that there are no rank
2942  // reductions.
2943  MemRefType expectedType = SubViewOp::inferResultType(
2944  baseType, staticOffsets, staticSizes, staticStrides);
2945 
2946  // Verify all properties of a shaped type: rank, element type and dimension
2947  // sizes. This takes into account potential rank reductions.
2948  auto shapedTypeVerification = isRankReducedType(
2949  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2950  if (shapedTypeVerification != SliceVerificationResult::Success)
2951  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2952 
2953  // Make sure that the memory space did not change.
2954  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
2955  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
2956  *this, expectedType);
2957 
2958  // Verify the offset of the layout map.
2959  if (!haveCompatibleOffsets(expectedType, subViewType))
2960  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2961  *this, expectedType);
2962 
2963  // The only things left to verify now are the strides. First, compute
2964  // the unused dimensions due to rank reductions. We have to look at sizes and
2965  // strides to decide which dimensions were dropped. This function also
2966  // partially verifies strides in case of rank reductions.
2967  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
2968  getMixedSizes());
2969  if (failed(unusedDims))
2970  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2971  *this, expectedType);
2972 
2973  // Strides must match.
2974  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
2975  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
2976  *this, expectedType);
2977 
2978  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2979  // to the base memref.
2980  SliceBoundsVerificationResult boundsResult =
2981  verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
2982  staticStrides, /*generateErrorMessage=*/true);
2983  if (!boundsResult.isValid)
2984  return getOperation()->emitError(boundsResult.errorMessage);
2985 
2986  return success();
2987 }
2988 
2989 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2990  return os << "range " << range.offset << ":" << range.size << ":"
2991  << range.stride;
2992 }
2993 
2994 /// Return the list of Range (i.e. offset, size, stride). Each Range
2995 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2996 /// with `b` at location `loc`.
2997 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2998  OpBuilder &b, Location loc) {
2999  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3000  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3001  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3002  SmallVector<Range, 8> res;
3003  unsigned rank = ranks[0];
3004  res.reserve(rank);
3005  for (unsigned idx = 0; idx < rank; ++idx) {
3006  Value offset =
3007  op.isDynamicOffset(idx)
3008  ? op.getDynamicOffset(idx)
3009  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
3010  Value size =
3011  op.isDynamicSize(idx)
3012  ? op.getDynamicSize(idx)
3013  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
3014  Value stride =
3015  op.isDynamicStride(idx)
3016  ? op.getDynamicStride(idx)
3017  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
3018  res.emplace_back(Range{offset, size, stride});
3019  }
3020  return res;
3021 }
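// A minimal usage sketch (hypothetical, not part of the original source):
// given any op implementing OffsetSizeAndStrideOpInterface, e.g. a
// memref.subview, the helper above yields one Range per dimension and
// materializes arith.constant ops of index type for the static entries:
//
//   SmallVector<Range, 8> ranges = getOrCreateRanges(subViewOp, builder, loc);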
3022 
3023 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3024 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3025 /// the rank of the inferred result type if `currentResultType` is lower rank
3026 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3027 /// together with the result type. In this case, it is important to compute
3028 /// the dropped dimensions using `currentSourceType` whose strides align with
3029 /// `currentResultType`.
3030 static MemRefType getCanonicalSubViewResultType(
3031  MemRefType currentResultType, MemRefType currentSourceType,
3032  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3033  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3034  MemRefType nonRankReducedType = SubViewOp::inferResultType(
3035  sourceType, mixedOffsets, mixedSizes, mixedStrides);
3036  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3037  currentSourceType, currentResultType, mixedSizes);
3038  if (failed(unusedDims))
3039  return nullptr;
3040 
3041  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3042  SmallVector<int64_t> shape, strides;
3043  unsigned numDimsAfterReduction =
3044  nonRankReducedType.getRank() - unusedDims->count();
3045  shape.reserve(numDimsAfterReduction);
3046  strides.reserve(numDimsAfterReduction);
3047  for (const auto &[idx, size, stride] :
3048  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3049  nonRankReducedType.getShape(), layout.getStrides())) {
3050  if (unusedDims->test(idx))
3051  continue;
3052  shape.push_back(size);
3053  strides.push_back(stride);
3054  }
3055 
3056  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3057  StridedLayoutAttr::get(sourceType.getContext(),
3058  layout.getOffset(), strides),
3059  nonRankReducedType.getMemorySpace());
3060 }
3061 
3062 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3063  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3064  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3065  unsigned rank = memrefType.getRank();
3066  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3067  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3068  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3069  MemRefType targetType = SubViewOp::inferRankReducedResultType(
3070  targetShape, memrefType, offsets, sizes, strides);
3071  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3072  sizes, strides);
3073 }
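// Illustrative sketch (hypothetical example, not part of the original
// source): for `%src : memref<4x1x16xf32>` and `targetShape = {4, 16}`, the
// helper above builds a zero-offset, unit-stride, rank-reducing subview along
// the lines of:
//
//   %0 = memref.subview %src[0, 0, 0] [4, 1, 16] [1, 1, 1]
//          : memref<4x1x16xf32> to memref<4x16xf32, strided<[16, 1]>>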
3074 
3075 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3076  Value value,
3077  ArrayRef<int64_t> desiredShape) {
3078  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3079  assert(sourceMemrefType && "not a ranked memref type");
3080  auto sourceShape = sourceMemrefType.getShape();
3081  if (sourceShape.equals(desiredShape))
3082  return value;
3083  auto maybeRankReductionMask =
3084  mlir::computeRankReductionMask(sourceShape, desiredShape);
3085  if (!maybeRankReductionMask)
3086  return failure();
3087  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3088 }
3089 
3090 /// Helper method to check if a `subview` operation is trivially a no-op. This
3091 /// is the case if all offsets are zero, all strides are 1, and the source
3092 /// shape is the same as the size of the subview. In such cases, the subview
3093 /// can be folded into its source.
3094 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3095  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3096  return false;
3097 
3098  auto mixedOffsets = subViewOp.getMixedOffsets();
3099  auto mixedSizes = subViewOp.getMixedSizes();
3100  auto mixedStrides = subViewOp.getMixedStrides();
3101 
3102  // Check offsets are zero.
3103  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3104  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3105  return !intValue || intValue.value() != 0;
3106  }))
3107  return false;
3108 
3109  // Check strides are one.
3110  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3111  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3112  return !intValue || intValue.value() != 1;
3113  }))
3114  return false;
3115 
3116  // Check that all size values are static and match the (static) source shape.
3117  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3118  for (const auto &size : llvm::enumerate(mixedSizes)) {
3119  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3120  if (!intValue || *intValue != sourceShape[size.index()])
3121  return false;
3122  }
3123  // All conditions met. The `SubViewOp` is foldable as a no-op.
3124  return true;
3125 }
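// Illustrative sketch (hypothetical example, not part of the original
// source): a subview with all-zero offsets, all-one strides, and sizes equal
// to the static source shape, e.g.
//
//   %1 = memref.subview %0[0, 0] [8, 16] [1, 1]
//          : memref<8x16xf32> to memref<8x16xf32, strided<[16, 1]>>
//
// is trivial in the sense above; TrivialSubViewOpFolder below replaces it
// with its source (through a memref.cast when only the layouts differ).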
3126 
3127 namespace {
3128 /// Pattern to rewrite a subview op with MemRefCast arguments.
3129 /// This essentially pushes memref.cast past its consuming subview when
3130 /// `canFoldIntoConsumerOp` is true.
3131 ///
3132 /// Example:
3133 /// ```
3134 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3135 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3136 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3137 /// ```
3138 /// is rewritten into:
3139 /// ```
3140 /// %0 = memref.subview %V[0, 0][3, 4][1, 1] : memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3141 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3142 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3143 /// ```
3144 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3145 public:
3146  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3147 
3148  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3149  PatternRewriter &rewriter) const override {
3150  // Any constant operand, just return to let SubViewOpConstantFolder kick
3151  // in.
3152  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3153  return matchPattern(operand, matchConstantIndex());
3154  }))
3155  return failure();
3156 
3157  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3158  if (!castOp)
3159  return failure();
3160 
3161  if (!CastOp::canFoldIntoConsumerOp(castOp))
3162  return failure();
3163 
3164  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3165  // the MemRefCastOp source operand type to infer the result type and the
3166  // current SubViewOp source operand type to compute the dropped dimensions
3167  // if the operation is rank-reducing.
3168  auto resultType = getCanonicalSubViewResultType(
3169  subViewOp.getType(), subViewOp.getSourceType(),
3170  llvm::cast<MemRefType>(castOp.getSource().getType()),
3171  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3172  subViewOp.getMixedStrides());
3173  if (!resultType)
3174  return failure();
3175 
3176  Value newSubView = rewriter.create<SubViewOp>(
3177  subViewOp.getLoc(), resultType, castOp.getSource(),
3178  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3179  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3180  subViewOp.getStaticStrides());
3181  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3182  newSubView);
3183  return success();
3184  }
3185 };
3186 
3187 /// Canonicalize subview ops that are no-ops, replacing them with their source,
3188 /// or with a cast when the types differ only in layout (e.g. via `affine_map`).
3189 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3190 public:
3191  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3192 
3193  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3194  PatternRewriter &rewriter) const override {
3195  if (!isTrivialSubViewOp(subViewOp))
3196  return failure();
3197  if (subViewOp.getSourceType() == subViewOp.getType()) {
3198  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3199  return success();
3200  }
3201  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3202  subViewOp.getSource());
3203  return success();
3204  }
3205 };
3206 } // namespace
3207 
3208 /// Return the canonical type of the result of a subview.
3209 struct SubViewReturnTypeCanonicalizer {
3210  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3211  ArrayRef<OpFoldResult> mixedSizes,
3212  ArrayRef<OpFoldResult> mixedStrides) {
3213  // Infer a memref type without taking into account any rank reductions.
3214  MemRefType resTy = SubViewOp::inferResultType(
3215  op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3216  if (!resTy)
3217  return {};
3218  MemRefType nonReducedType = resTy;
3219 
3220  // Directly return the non-rank reduced type if there are no dropped dims.
3221  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3222  if (droppedDims.none())
3223  return nonReducedType;
3224 
3225  // Take the strides and offset from the non-rank reduced type.
3226  auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3227 
3228  // Drop dims from shape and strides.
3229  SmallVector<int64_t> targetShape;
3230  SmallVector<int64_t> targetStrides;
3231  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3232  if (droppedDims.test(i))
3233  continue;
3234  targetStrides.push_back(nonReducedStrides[i]);
3235  targetShape.push_back(nonReducedType.getDimSize(i));
3236  }
3237 
3238  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3239  StridedLayoutAttr::get(nonReducedType.getContext(),
3240  offset, targetStrides),
3241  nonReducedType.getMemorySpace());
3242  }
3243 };
3244 
3245 /// A canonicalizer wrapper to replace SubViewOps.
3246 struct SubViewCanonicalizer {
3247  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3248  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3249  }
3250 };
3251 
3252 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3253  MLIRContext *context) {
3254  results
3255  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3256  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3257  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3258 }
3259 
3260 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3261  MemRefType sourceMemrefType = getSource().getType();
3262  MemRefType resultMemrefType = getResult().getType();
3263  auto resultLayout =
3264  dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3265 
3266  if (resultMemrefType == sourceMemrefType &&
3267  resultMemrefType.hasStaticShape() &&
3268  (!resultLayout || resultLayout.hasStaticLayout())) {
3269  return getViewSource();
3270  }
3271 
3272  // Fold subview(subview(x)), where both subviews have the same size and the
3273  // second subview's offsets are all zero. (I.e., the second subview is a
3274  // no-op.)
3275  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3276  auto srcSizes = srcSubview.getMixedSizes();
3277  auto sizes = getMixedSizes();
3278  auto offsets = getMixedOffsets();
3279  bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3280  auto strides = getMixedStrides();
3281  bool allStridesOne = llvm::all_of(strides, isOneInteger);
3282  bool allSizesSame = llvm::equal(sizes, srcSizes);
3283  if (allOffsetsZero && allStridesOne && allSizesSame &&
3284  resultMemrefType == sourceMemrefType)
3285  return getViewSource();
3286  }
3287 
3288  return {};
3289 }
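// Illustrative sketch of the subview(subview(x)) fold above (hypothetical
// example, not part of the original source):
//
//   %0 = memref.subview %base[2, 2] [4, 4] [1, 1]
//          : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: 18>>
//   %1 = memref.subview %0[0, 0] [4, 4] [1, 1]
//          : memref<4x4xf32, strided<[8, 1], offset: 18>>
//            to memref<4x4xf32, strided<[8, 1], offset: 18>>
//
// %1 has all-zero offsets, all-one strides, the same sizes as %0, and the
// same result type as its source, so fold() replaces it with %0.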
3290 
3291 //===----------------------------------------------------------------------===//
3292 // TransposeOp
3293 //===----------------------------------------------------------------------===//
3294 
3295 void TransposeOp::getAsmResultNames(
3296  function_ref<void(Value, StringRef)> setNameFn) {
3297  setNameFn(getResult(), "transpose");
3298 }
3299 
3300 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3301 static MemRefType inferTransposeResultType(MemRefType memRefType,
3302  AffineMap permutationMap) {
3303  auto originalSizes = memRefType.getShape();
3304  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3305  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3306 
3307  // Compute permuted sizes and strides.
3308  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3309  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3310 
3311  return MemRefType::Builder(memRefType)
3312  .setShape(sizes)
3313  .setLayout(
3314  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3315 }
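// Illustrative sketch (hypothetical example, not part of the original
// source): for a source type memref<4x8xf32> (strides [8, 1], offset 0) and
// the permutation map (d0, d1) -> (d1, d0), the helper above infers
// memref<8x4xf32, strided<[1, 8]>>, i.e. the shape and the strides are both
// permuted while the offset is left unchanged.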
3316 
3317 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3318  AffineMapAttr permutation,
3319  ArrayRef<NamedAttribute> attrs) {
3320  auto permutationMap = permutation.getValue();
3321  assert(permutationMap);
3322 
3323  auto memRefType = llvm::cast<MemRefType>(in.getType());
3324  // Compute result type.
3325  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3326 
3327  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3328  build(b, result, resultType, in, attrs);
3329 }
3330 
3331 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3332 void TransposeOp::print(OpAsmPrinter &p) {
3333  p << " " << getIn() << " " << getPermutation();
3334  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3335  p << " : " << getIn().getType() << " to " << getType();
3336 }
3337 
3338 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3339  OpAsmParser::UnresolvedOperand in;
3340  AffineMap permutation;
3341  MemRefType srcType, dstType;
3342  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3343  parser.parseOptionalAttrDict(result.attributes) ||
3344  parser.parseColonType(srcType) ||
3345  parser.resolveOperand(in, srcType, result.operands) ||
3346  parser.parseKeywordType("to", dstType) ||
3347  parser.addTypeToList(dstType, result.types))
3348  return failure();
3349 
3350  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3351  AffineMapAttr::get(permutation));
3352  return success();
3353 }
3354 
3355 LogicalResult TransposeOp::verify() {
3356  if (!getPermutation().isPermutation())
3357  return emitOpError("expected a permutation map");
3358  if (getPermutation().getNumDims() != getIn().getType().getRank())
3359  return emitOpError("expected a permutation map of same rank as the input");
3360 
3361  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3362  auto resultType = llvm::cast<MemRefType>(getType());
3363  auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3364  .canonicalizeStridedLayout();
3365 
3366  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3367  return emitOpError("result type ")
3368  << resultType
3369  << " is not equivalent to the canonical transposed input type "
3370  << canonicalResultType;
3371  return success();
3372 }
3373 
3374 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3375  // First check for an identity permutation; we can fold it away if the input
3376  // and result types are already identical.
3377  if (getPermutation().isIdentity() && getType() == getIn().getType())
3378  return getIn();
3379  // Fold two consecutive memref.transpose Ops into one by composing their
3380  // permutation maps.
3381  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3382  AffineMap composedPermutation =
3383  getPermutation().compose(otherTransposeOp.getPermutation());
3384  getInMutable().assign(otherTransposeOp.getIn());
3385  setPermutation(composedPermutation);
3386  return getResult();
3387  }
3388  return {};
3389 }
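// Illustrative sketch of the transpose-of-transpose fold above (hypothetical
// example, not part of the original source):
//
//   %0 = memref.transpose %arg (d0, d1) -> (d1, d0)
//          : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
//   %1 = memref.transpose %0 (d0, d1) -> (d1, d0)
//          : memref<8x4xf32, strided<[1, 8]>> to memref<4x8xf32, strided<[8, 1]>>
//
// fold() rewrites %1 in place to read directly from %arg with the composed
// permutation (here the identity map), so %1 no longer uses the intermediate %0.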
3390 
3391 //===----------------------------------------------------------------------===//
3392 // ViewOp
3393 //===----------------------------------------------------------------------===//
3394 
3395 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3396  setNameFn(getResult(), "view");
3397 }
3398 
3399 LogicalResult ViewOp::verify() {
3400  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3401  auto viewType = getType();
3402 
3403  // The base memref should have identity layout map (or none).
3404  if (!baseType.getLayout().isIdentity())
3405  return emitError("unsupported map for base memref type ") << baseType;
3406 
3407  // The result memref should have identity layout map (or none).
3408  if (!viewType.getLayout().isIdentity())
3409  return emitError("unsupported map for result memref type ") << viewType;
3410 
3411  // The base memref and the view memref should be in the same memory space.
3412  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3413  return emitError("different memory spaces specified for base memref "
3414  "type ")
3415  << baseType << " and view memref type " << viewType;
3416 
3417  // Verify that we have the correct number of sizes for the result type.
3418  unsigned numDynamicDims = viewType.getNumDynamicDims();
3419  if (getSizes().size() != numDynamicDims)
3420  return emitError("incorrect number of size operands for type ") << viewType;
3421 
3422  return success();
3423 }
3424 
3425 Value ViewOp::getViewSource() { return getSource(); }
3426 
3427 OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3428  MemRefType sourceMemrefType = getSource().getType();
3429  MemRefType resultMemrefType = getResult().getType();
3430 
3431  if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
3432  return getViewSource();
3433 
3434  return {};
3435 }
3436 
3437 namespace {
3438 
3439 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3440  using OpRewritePattern<ViewOp>::OpRewritePattern;
3441 
3442  LogicalResult matchAndRewrite(ViewOp viewOp,
3443  PatternRewriter &rewriter) const override {
3444  // Return if none of the operands are constants.
3445  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3446  return matchPattern(operand, matchConstantIndex());
3447  }))
3448  return failure();
3449 
3450  // Get result memref type.
3451  auto memrefType = viewOp.getType();
3452 
3453  // Get the strides and offset from the old memref view type 'memrefType'.
3454  int64_t oldOffset;
3455  SmallVector<int64_t, 4> oldStrides;
3456  if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3457  return failure();
3458  assert(oldOffset == 0 && "Expected 0 offset");
3459 
3460  SmallVector<Value, 4> newOperands;
3461 
3462  // Offset cannot be folded into result type.
3463 
3464  // Fold any dynamic dim operands which are produced by a constant.
3465  SmallVector<int64_t, 4> newShapeConstants;
3466  newShapeConstants.reserve(memrefType.getRank());
3467 
3468  unsigned dynamicDimPos = 0;
3469  unsigned rank = memrefType.getRank();
3470  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3471  int64_t dimSize = memrefType.getDimSize(dim);
3472  // If this is already a static dimension, keep it.
3473  if (ShapedType::isStatic(dimSize)) {
3474  newShapeConstants.push_back(dimSize);
3475  continue;
3476  }
3477  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3478  if (auto constantIndexOp =
3479  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3480  // Dynamic shape dimension will be folded.
3481  newShapeConstants.push_back(constantIndexOp.value());
3482  } else {
3483  // Dynamic shape dimension not folded; copy operand from old memref.
3484  newShapeConstants.push_back(dimSize);
3485  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3486  }
3487  dynamicDimPos++;
3488  }
3489 
3490  // Create new memref type with constant folded dims.
3491  MemRefType newMemRefType =
3492  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3493  // Nothing new, don't fold.
3494  if (newMemRefType == memrefType)
3495  return failure();
3496 
3497  // Create new ViewOp.
3498  auto newViewOp = rewriter.create<ViewOp>(
3499  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3500  viewOp.getByteShift(), newOperands);
3501  // Insert a cast so we have the same type as the old memref type.
3502  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3503  return success();
3504  }
3505 };
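// Illustrative sketch of ViewOpShapeFolder (hypothetical example, not part of
// the original source): when a dynamic size operand is defined by a constant,
//
//   %c64 = arith.constant 64 : index
//   %v = memref.view %buf[%offset][%c64] : memref<2048xi8> to memref<?x4xf32>
//
// the pattern folds the constant into the result type and casts back to the
// original type:
//
//   %v0 = memref.view %buf[%offset][] : memref<2048xi8> to memref<64x4xf32>
//   %v1 = memref.cast %v0 : memref<64x4xf32> to memref<?x4xf32>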
3506 
3507 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3508  using OpRewritePattern<ViewOp>::OpRewritePattern;
3509 
3510  LogicalResult matchAndRewrite(ViewOp viewOp,
3511  PatternRewriter &rewriter) const override {
3512  Value memrefOperand = viewOp.getOperand(0);
3513  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3514  if (!memrefCastOp)
3515  return failure();
3516  Value allocOperand = memrefCastOp.getOperand();
3517  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3518  if (!allocOp)
3519  return failure();
3520  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3521  viewOp.getByteShift(),
3522  viewOp.getSizes());
3523  return success();
3524  }
3525 };
3526 
3527 } // namespace
3528 
3529 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3530  MLIRContext *context) {
3531  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3532 }
3533 
3534 //===----------------------------------------------------------------------===//
3535 // AtomicRMWOp
3536 //===----------------------------------------------------------------------===//
3537 
3538 LogicalResult AtomicRMWOp::verify() {
3539  if (getMemRefType().getRank() != getNumOperands() - 2)
3540  return emitOpError(
3541  "expects the number of subscripts to be equal to memref rank");
3542  switch (getKind()) {
3543  case arith::AtomicRMWKind::addf:
3544  case arith::AtomicRMWKind::maximumf:
3545  case arith::AtomicRMWKind::minimumf:
3546  case arith::AtomicRMWKind::mulf:
3547  if (!llvm::isa<FloatType>(getValue().getType()))
3548  return emitOpError() << "with kind '"
3549  << arith::stringifyAtomicRMWKind(getKind())
3550  << "' expects a floating-point type";
3551  break;
3552  case arith::AtomicRMWKind::addi:
3553  case arith::AtomicRMWKind::maxs:
3554  case arith::AtomicRMWKind::maxu:
3555  case arith::AtomicRMWKind::mins:
3556  case arith::AtomicRMWKind::minu:
3557  case arith::AtomicRMWKind::muli:
3558  case arith::AtomicRMWKind::ori:
3559  case arith::AtomicRMWKind::andi:
3560  if (!llvm::isa<IntegerType>(getValue().getType()))
3561  return emitOpError() << "with kind '"
3562  << arith::stringifyAtomicRMWKind(getKind())
3563  << "' expects an integer type";
3564  break;
3565  default:
3566  break;
3567  }
3568  return success();
3569 }
3570 
3571 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3572  /// atomicrmw(memrefcast) -> atomicrmw
3573  if (succeeded(foldMemRefCast(*this, getValue())))
3574  return getResult();
3575  return OpFoldResult();
3576 }
3577 
3578 //===----------------------------------------------------------------------===//
3579 // TableGen'd op method definitions
3580 //===----------------------------------------------------------------------===//
3581 
3582 #define GET_OP_CLASSES
3583 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"