MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::memref;
29 
30 /// Materialize a single constant operation from a given attribute value with
31 /// the desired resultant type.
32 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
33  Attribute value, Type type,
34  Location loc) {
35  return arith::ConstantOp::materialize(builder, value, type, loc);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Common canonicalization pattern support logic
40 //===----------------------------------------------------------------------===//
41 
42 /// This is a common utility used for patterns of the form
43 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
44 /// into the root operation directly.
45 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
46  bool folded = false;
47  for (OpOperand &operand : op->getOpOperands()) {
48  auto cast = operand.get().getDefiningOp<CastOp>();
49  if (cast && operand.get() != inner &&
50  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
51  operand.set(cast.getOperand());
52  folded = true;
53  }
54  }
55  return success(folded);
56 }
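// Illustrative sketch (editorial addition, not part of the upstream source):
// the kind of rewrite foldMemRefCast enables, shown with memref.dealloc,
// whose fold hook calls this helper (see DeallocOp::fold below). Given:
//
//   %0 = memref.cast %m : memref<8xf32> to memref<?xf32>
//   memref.dealloc %0 : memref<?xf32>
//
// the dealloc may be folded to use the cast source directly:
//
//   memref.dealloc %m : memref<8xf32>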
57 
58 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
59 /// type.
60 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
61  if (auto memref = llvm::dyn_cast<MemRefType>(type))
62  return RankedTensorType::get(memref.getShape(), memref.getElementType());
63  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
64  return UnrankedTensorType::get(memref.getElementType());
65  return NoneType::get(type.getContext());
66 }
67 
68 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
69  int64_t dim) {
70  auto memrefType = llvm::cast<MemRefType>(value.getType());
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78  Location loc, Value value) {
79  auto memrefType = llvm::cast<MemRefType>(value.getType());
80  SmallVector<OpFoldResult> result;
81  for (int64_t i = 0; i < memrefType.getRank(); ++i)
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result;
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that sets values[i] to constValues[i] if the latter is a
91 /// static value, i.e., not ShapedType::kDynamic.
92 ///
93 /// If constValues[i] is dynamic, tries to extract a constant value from
94 /// value[i] to allow for additional folding opportunities. Also converts all
95 /// existing attributes to index attributes. (They may be i64 attributes.)
96 static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
97  ArrayRef<int64_t> constValues) {
98  assert(constValues.size() == values.size() &&
99  "incorrect number of const values");
100  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
101  Builder builder(values[i].getContext());
102  if (ShapedType::isStatic(cstVal)) {
103  // Constant value is known, use it directly.
104  values[i] = builder.getIndexAttr(cstVal);
105  continue;
106  }
107  if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
108  // Try to extract a constant or convert an existing to index.
109  values[i] = builder.getIndexAttr(*cst);
110  }
111  }
112 }
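// Illustrative sketch (editorial addition), with assumed inputs: given
//   values      = [%c8 (a Value produced by arith.constant 8 : index), 4 : i64]
//   constValues = [ShapedType::kDynamic, 4]
// the helper rewrites values to [8 : index, 4 : index]: the first entry is
// constified from the matched constant, the second is converted from an i64
// attribute to an index attribute.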
113 
114 //===----------------------------------------------------------------------===//
115 // AllocOp / AllocaOp
116 //===----------------------------------------------------------------------===//
117 
118 void AllocOp::getAsmResultNames(
119  function_ref<void(Value, StringRef)> setNameFn) {
120  setNameFn(getResult(), "alloc");
121 }
122 
123 void AllocaOp::getAsmResultNames(
124  function_ref<void(Value, StringRef)> setNameFn) {
125  setNameFn(getResult(), "alloca");
126 }
127 
128 template <typename AllocLikeOp>
129 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
130  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
131  "applies to only alloc or alloca");
132  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
133  if (!memRefType)
134  return op.emitOpError("result must be a memref");
135 
136  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
137  return op.emitOpError("dimension operand count does not equal memref "
138  "dynamic dimension count");
139 
140  unsigned numSymbols = 0;
141  if (!memRefType.getLayout().isIdentity())
142  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
143  if (op.getSymbolOperands().size() != numSymbols)
144  return op.emitOpError("symbol operand count does not equal memref symbol "
145  "count: expected ")
146  << numSymbols << ", got " << op.getSymbolOperands().size();
147 
148  return success();
149 }
150 
151 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
152 
153 LogicalResult AllocaOp::verify() {
154  // An alloca op needs to have an ancestor with an allocation scope trait.
155  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
156  return emitOpError(
157  "requires an ancestor op with AutomaticAllocationScope trait");
158 
159  return verifyAllocLikeOp(*this);
160 }
161 
162 namespace {
163 /// Fold constant dimensions into an alloc like operation.
164 template <typename AllocLikeOp>
165 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
166  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
167 
168  LogicalResult matchAndRewrite(AllocLikeOp alloc,
169  PatternRewriter &rewriter) const override {
170  // Check to see if any dimension operands are constants. If so, we can
171  // substitute and drop them.
172  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
173  APInt constSizeArg;
174  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
175  return false;
176  return constSizeArg.isNonNegative();
177  }))
178  return failure();
179 
180  auto memrefType = alloc.getType();
181 
182  // Ok, we have one or more constant operands. Collect the non-constant ones
183  // and keep track of the resultant memref type to build.
184  SmallVector<int64_t, 4> newShapeConstants;
185  newShapeConstants.reserve(memrefType.getRank());
186  SmallVector<Value, 4> dynamicSizes;
187 
188  unsigned dynamicDimPos = 0;
189  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
190  int64_t dimSize = memrefType.getDimSize(dim);
191  // If this is already a static dimension, keep it.
192  if (ShapedType::isStatic(dimSize)) {
193  newShapeConstants.push_back(dimSize);
194  continue;
195  }
196  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
197  APInt constSizeArg;
198  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
199  constSizeArg.isNonNegative()) {
200  // Dynamic shape dimension will be folded.
201  newShapeConstants.push_back(constSizeArg.getZExtValue());
202  } else {
203  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
204  newShapeConstants.push_back(ShapedType::kDynamic);
205  dynamicSizes.push_back(dynamicSize);
206  }
207  dynamicDimPos++;
208  }
209 
210  // Create new memref type (which will have fewer dynamic dimensions).
211  MemRefType newMemRefType =
212  MemRefType::Builder(memrefType).setShape(newShapeConstants);
213  assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
214 
215  // Create and insert the alloc op for the new memref.
216  auto newAlloc = rewriter.create<AllocLikeOp>(
217  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
218  alloc.getAlignmentAttr());
219  // Insert a cast so we have the same type as the old alloc.
220  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
221  return success();
222  }
223 };
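// Illustrative sketch (editorial addition), with assumed SSA names: given
//
//   %c16 = arith.constant 16 : index
//   %0 = memref.alloc(%c16, %n) : memref<?x?xf32>
//
// SimplifyAllocConst folds the constant size into the type and inserts a cast
// back to the original type:
//
//   %1 = memref.alloc(%n) : memref<16x?xf32>
//   %0 = memref.cast %1 : memref<16x?xf32> to memref<?x?xf32>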
224 
225 /// Fold alloc operations with no users or only store and dealloc uses.
226 template <typename T>
227 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
228  using OpRewritePattern<T>::OpRewritePattern;
229 
230  LogicalResult matchAndRewrite(T alloc,
231  PatternRewriter &rewriter) const override {
232  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
233  if (auto storeOp = dyn_cast<StoreOp>(op))
234  return storeOp.getValue() == alloc;
235  return !isa<DeallocOp>(op);
236  }))
237  return failure();
238 
239  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
240  rewriter.eraseOp(user);
241 
242  rewriter.eraseOp(alloc);
243  return success();
244  }
245 };
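// Illustrative sketch (editorial addition), with assumed SSA names: an
// allocation whose only uses are stores into it and a dealloc is erased
// entirely by SimplifyDeadAlloc:
//
//   %0 = memref.alloc() : memref<4xf32>
//   memref.store %cst, %0[%i] : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>
//
// Note that a store *of* %0 (as the stored value) would block the rewrite.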
246 } // namespace
247 
248 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
249  MLIRContext *context) {
250  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
251 }
252 
253 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
254  MLIRContext *context) {
255  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
256  context);
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // ReallocOp
261 //===----------------------------------------------------------------------===//
262 
263 LogicalResult ReallocOp::verify() {
264  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
265  MemRefType resultType = getType();
266 
267  // The source memref should have identity layout (or none).
268  if (!sourceType.getLayout().isIdentity())
269  return emitError("unsupported layout for source memref type ")
270  << sourceType;
271 
272  // The result memref should have identity layout (or none).
273  if (!resultType.getLayout().isIdentity())
274  return emitError("unsupported layout for result memref type ")
275  << resultType;
276 
277  // The source memref and the result memref should be in the same memory space.
278  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
279  return emitError("different memory spaces specified for source memref "
280  "type ")
281  << sourceType << " and result memref type " << resultType;
282 
283  // The source memref and the result memref should have the same element type.
284  if (sourceType.getElementType() != resultType.getElementType())
285  return emitError("different element types specified for source memref "
286  "type ")
287  << sourceType << " and result memref type " << resultType;
288 
289  // Verify that we have the dynamic dimension operand when it is needed.
290  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
291  return emitError("missing dimension operand for result type ")
292  << resultType;
293  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
294  return emitError("unnecessary dimension operand for result type ")
295  << resultType;
296 
297  return success();
298 }
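// Illustrative sketch (editorial addition), with assumed SSA names: IR
// accepted by this verifier. A dynamic result size operand is required
// exactly when the result type has a dynamic dimension:
//
//   %a = memref.realloc %src : memref<64xf32> to memref<128xf32>
//   %b = memref.realloc %src(%d) : memref<64xf32> to memref<?xf32>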
299 
300 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
301  MLIRContext *context) {
302  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // AllocaScopeOp
307 //===----------------------------------------------------------------------===//
308 
309 void AllocaScopeOp::print(OpAsmPrinter &p) {
310  bool printBlockTerminators = false;
311 
312  p << ' ';
313  if (!getResults().empty()) {
314  p << " -> (" << getResultTypes() << ")";
315  printBlockTerminators = true;
316  }
317  p << ' ';
318  p.printRegion(getBodyRegion(),
319  /*printEntryBlockArgs=*/false,
320  /*printBlockTerminators=*/printBlockTerminators);
321  p.printOptionalAttrDict((*this)->getAttrs());
322 }
323 
324 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
325  // Create a region for the body.
326  result.regions.reserve(1);
327  Region *bodyRegion = result.addRegion();
328 
329  // Parse optional results type list.
330  if (parser.parseOptionalArrowTypeList(result.types))
331  return failure();
332 
333  // Parse the body region.
334  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
335  return failure();
336  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
337  result.location);
338 
339  // Parse the optional attribute list.
340  if (parser.parseOptionalAttrDict(result.attributes))
341  return failure();
342 
343  return success();
344 }
345 
346 void AllocaScopeOp::getSuccessorRegions(
347  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
348  if (!point.isParent()) {
349  regions.push_back(RegionSuccessor(getResults()));
350  return;
351  }
352 
353  regions.push_back(RegionSuccessor(&getBodyRegion()));
354 }
355 
356 /// Given an operation, return whether this op is guaranteed to
357 /// allocate an AutomaticAllocationScopeResource
358 static bool isGuaranteedAutomaticAllocation(Operation *op) {
359  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
360  if (!interface)
361  return false;
362  for (auto res : op->getResults()) {
363  if (auto effect =
364  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
365  if (isa<SideEffects::AutomaticAllocationScopeResource>(
366  effect->getResource()))
367  return true;
368  }
369  }
370  return false;
371 }
372 
373 /// Given an operation, return whether this op itself could
374 /// allocate an AutomaticAllocationScopeResource. Note that
375 /// this will not check whether an operation contained within
376 /// the op can allocate.
377 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
378  // This op itself doesn't create a stack allocation,
379  // the inner allocation should be handled separately.
380  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
381  return false;
382  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
383  if (!interface)
384  return true;
385  for (auto res : op->getResults()) {
386  if (auto effect =
387  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
388  if (isa<SideEffects::AutomaticAllocationScopeResource>(
389  effect->getResource()))
390  return true;
391  }
392  }
393  return false;
394 }
395 
396 /// Return whether this op is the last non-terminating op
397 /// in a region. That is to say, it is in a one-block region
398 /// and is only followed by a terminator. This prevents
399 /// extending the lifetime of allocations.
400 static bool lastNonTerminatorInRegion(Operation *op) {
401  return op->getBlock()->mightHaveTerminator() &&
402  op->getNextNode() == op->getBlock()->getTerminator() &&
403  op->getParentRegion()->hasOneBlock();
404 }
405 
406 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
407 /// or it contains no allocation.
408 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
409  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
410 
411  LogicalResult matchAndRewrite(AllocaScopeOp op,
412  PatternRewriter &rewriter) const override {
413  bool hasPotentialAlloca =
414  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
415  if (alloc == op)
416  return WalkResult::advance();
417  if (isOpItselfPotentialAutomaticAllocation(alloc))
418  return WalkResult::interrupt();
419  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
420  return WalkResult::skip();
421  return WalkResult::advance();
422  }).wasInterrupted();
423 
424  // If this contains no potential allocation, it is always legal to
425  // inline. Otherwise, consider two conditions:
426  if (hasPotentialAlloca) {
427  // If the parent isn't an allocation scope, or we are not the last
428  // non-terminator op in the parent, we will extend the lifetime.
429  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
430  return failure();
431  if (!lastNonTerminatorInRegion(op))
432  return failure();
433  }
434 
435  Block *block = &op.getRegion().front();
436  Operation *terminator = block->getTerminator();
437  ValueRange results = terminator->getOperands();
438  rewriter.inlineBlockBefore(block, op);
439  rewriter.replaceOp(op, results);
440  rewriter.eraseOp(terminator);
441  return success();
442  }
443 };
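// Illustrative sketch (editorial addition), with assumed SSA names: a scope
// that provably contains no stack allocation can always be inlined, e.g.
//
//   %r = memref.alloca_scope -> (f32) {
//     %sum = arith.addf %a, %b : f32
//     memref.alloca_scope.return %sum : f32
//   }
//
// becomes simply:
//
//   %r = arith.addf %a, %b : f32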
444 
445 /// Move allocations into an allocation scope, if it is legal to
446 /// move them (e.g. their operands are available at the location
447 /// the op would be moved to).
448 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
449  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
450 
451  LogicalResult matchAndRewrite(AllocaScopeOp op,
452  PatternRewriter &rewriter) const override {
453 
454  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
455  return failure();
456 
457  Operation *lastParentWithoutScope = op->getParentOp();
458 
459  if (!lastParentWithoutScope ||
460  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
461  return failure();
462 
463  // Only apply if this is the last non-terminator
464  // op in the block (lest lifetime be extended) of a
465  // one-block region.
466  if (!lastNonTerminatorInRegion(op) ||
467  !lastNonTerminatorInRegion(lastParentWithoutScope))
468  return failure();
469 
470  while (!lastParentWithoutScope->getParentOp()
471  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
472  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
473  if (!lastParentWithoutScope ||
474  !lastNonTerminatorInRegion(lastParentWithoutScope))
475  return failure();
476  }
477  assert(lastParentWithoutScope->getParentOp()
478  ->hasTrait<OpTrait::AutomaticAllocationScope>());
479 
480  Region *containingRegion = nullptr;
481  for (auto &r : lastParentWithoutScope->getRegions()) {
482  if (r.isAncestor(op->getParentRegion())) {
483  assert(containingRegion == nullptr &&
484  "only one region can contain the op");
485  containingRegion = &r;
486  }
487  }
488  assert(containingRegion && "op must be contained in a region");
489 
490  SmallVector<Operation *> toHoist;
491  op->walk([&](Operation *alloc) {
492  if (!isGuaranteedAutomaticAllocation(alloc))
493  return WalkResult::skip();
494 
495  // If any operand is not defined before the location of
496  // lastParentWithoutScope (i.e. where we would hoist to), skip.
497  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
498  return containingRegion->isAncestor(v.getParentRegion());
499  }))
500  return WalkResult::skip();
501  toHoist.push_back(alloc);
502  return WalkResult::advance();
503  });
504 
505  if (toHoist.empty())
506  return failure();
507  rewriter.setInsertionPoint(lastParentWithoutScope);
508  for (auto *op : toHoist) {
509  auto *cloned = rewriter.clone(*op);
510  rewriter.replaceOp(op, cloned->getResults());
511  }
512  return success();
513  }
514 };
515 
516 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
517  MLIRContext *context) {
518  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // AssumeAlignmentOp
523 //===----------------------------------------------------------------------===//
524 
525 LogicalResult AssumeAlignmentOp::verify() {
526  if (!llvm::isPowerOf2_32(getAlignment()))
527  return emitOpError("alignment must be power of 2");
528  return success();
529 }
530 
531 void AssumeAlignmentOp::getAsmResultNames(
532  function_ref<void(Value, StringRef)> setNameFn) {
533  setNameFn(getResult(), "assume_align");
534 }
535 
536 OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
537  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
538  if (!source)
539  return {};
540  if (source.getAlignment() != getAlignment())
541  return {};
542  return getMemref();
543 }
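// Illustrative sketch (editorial addition, assembly shown schematically):
// the fold above removes a duplicate alignment assumption. With assumed
// names, %1 folds to %0 because its operand is already the result of an
// assume_alignment with the same alignment:
//
//   %0 = memref.assume_alignment %m, 16 : memref<?xf32>
//   %1 = memref.assume_alignment %0, 16 : memref<?xf32>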
544 
545 //===----------------------------------------------------------------------===//
546 // CastOp
547 //===----------------------------------------------------------------------===//
548 
549 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
550  setNameFn(getResult(), "cast");
551 }
552 
553 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
554 /// source memref. This is useful to fold a memref.cast into a consuming op
555 /// and implement canonicalization patterns for ops in different dialects that
556 /// may consume the results of memref.cast operations. Such foldable memref.cast
557 /// operations are typically inserted as `view` and `subview` ops are
558 /// canonicalized, to preserve the type compatibility of their uses.
559 ///
560 /// Returns true when all conditions are met:
561 /// 1. source and result are ranked memrefs with strided semantics and same
562 /// element type and rank.
563 /// 2. each of the source's size, offset or stride has more static information
564 /// than the corresponding result's size, offset or stride.
565 ///
566 /// Example 1:
567 /// ```mlir
568 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
569 /// %2 = consumer %1 ... : memref<?x?xf32> ...
570 /// ```
571 ///
572 /// may fold into:
573 ///
574 /// ```mlir
575 /// %2 = consumer %0 ... : memref<8x16xf32> ...
576 /// ```
577 ///
578 /// Example 2:
579 /// ```
580 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
581 /// to memref<?x?xf32>
582 /// consumer %1 : memref<?x?xf32> ...
583 /// ```
584 ///
585 /// may fold into:
586 ///
587 /// ```
588 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
589 /// ```
590 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
591  MemRefType sourceType =
592  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
593  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
594 
595  // Requires ranked MemRefType.
596  if (!sourceType || !resultType)
597  return false;
598 
599  // Requires same elemental type.
600  if (sourceType.getElementType() != resultType.getElementType())
601  return false;
602 
603  // Requires same rank.
604  if (sourceType.getRank() != resultType.getRank())
605  return false;
606 
607  // Only fold casts between strided memref forms.
608  int64_t sourceOffset, resultOffset;
609  SmallVector<int64_t, 4> sourceStrides, resultStrides;
610  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
611  failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
612  return false;
613 
614  // If cast is towards more static sizes along any dimension, don't fold.
615  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
616  auto ss = std::get<0>(it), st = std::get<1>(it);
617  if (ss != st)
618  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
619  return false;
620  }
621 
622  // If cast is towards more static offset along any dimension, don't fold.
623  if (sourceOffset != resultOffset)
624  if (ShapedType::isDynamic(sourceOffset) &&
625  ShapedType::isStatic(resultOffset))
626  return false;
627 
628  // If cast is towards more static strides along any dimension, don't fold.
629  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
630  auto ss = std::get<0>(it), st = std::get<1>(it);
631  if (ss != st)
632  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
633  return false;
634  }
635 
636  return true;
637 }
638 
639 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
640  if (inputs.size() != 1 || outputs.size() != 1)
641  return false;
642  Type a = inputs.front(), b = outputs.front();
643  auto aT = llvm::dyn_cast<MemRefType>(a);
644  auto bT = llvm::dyn_cast<MemRefType>(b);
645 
646  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
647  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
648 
649  if (aT && bT) {
650  if (aT.getElementType() != bT.getElementType())
651  return false;
652  if (aT.getLayout() != bT.getLayout()) {
653  int64_t aOffset, bOffset;
654  SmallVector<int64_t, 4> aStrides, bStrides;
655  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
656  failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
657  aStrides.size() != bStrides.size())
658  return false;
659 
660  // Strides along a dimension/offset are compatible if the value in the
661  // source memref is static and the value in the target memref is the
662  // same. They are also compatible if either one is dynamic (see
663  // description of MemRefCastOp for details).
664  auto checkCompatible = [](int64_t a, int64_t b) {
665  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
666  };
667  if (!checkCompatible(aOffset, bOffset))
668  return false;
669  for (const auto &aStride : enumerate(aStrides))
670  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
671  return false;
672  }
673  if (aT.getMemorySpace() != bT.getMemorySpace())
674  return false;
675 
676  // They must have the same rank, and any specified dimensions must match.
677  if (aT.getRank() != bT.getRank())
678  return false;
679 
680  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
681  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
682  if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
683  aDim != bDim)
684  return false;
685  }
686  return true;
687  } else {
688  if (!aT && !uaT)
689  return false;
690  if (!bT && !ubT)
691  return false;
692  // Unranked to unranked casting is unsupported
693  if (uaT && ubT)
694  return false;
695 
696  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
697  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
698  if (aEltType != bEltType)
699  return false;
700 
701  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
702  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
703  return aMemSpace == bMemSpace;
704  }
705 
706  return false;
707 }
708 
709 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
710  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // CopyOp
715 //===----------------------------------------------------------------------===//
716 
717 namespace {
718 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
719 /// and element type, the cast can be skipped. Such CastOps only cast the layout
720 /// of the type.
721 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
722  using OpRewritePattern<CopyOp>::OpRewritePattern;
723 
724  LogicalResult matchAndRewrite(CopyOp copyOp,
725  PatternRewriter &rewriter) const override {
726  bool modified = false;
727 
728  // Check source.
729  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
730  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
731  auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
732 
733  if (fromType && toType) {
734  if (fromType.getShape() == toType.getShape() &&
735  fromType.getElementType() == toType.getElementType()) {
736  rewriter.modifyOpInPlace(copyOp, [&] {
737  copyOp.getSourceMutable().assign(castOp.getSource());
738  });
739  modified = true;
740  }
741  }
742  }
743 
744  // Check target.
745  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
746  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
747  auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
748 
749  if (fromType && toType) {
750  if (fromType.getShape() == toType.getShape() &&
751  fromType.getElementType() == toType.getElementType()) {
752  rewriter.modifyOpInPlace(copyOp, [&] {
753  copyOp.getTargetMutable().assign(castOp.getSource());
754  });
755  modified = true;
756  }
757  }
758  }
759 
760  return success(modified);
761  }
762 };
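// Illustrative sketch (editorial addition), with assumed SSA names:
// FoldCopyOfCast skips a cast that only changes the layout, so
//
//   %0 = memref.cast %src : memref<4xf32> to memref<4xf32, strided<[1]>>
//   memref.copy %0, %dst : memref<4xf32, strided<[1]>> to memref<4xf32>
//
// is rewritten in place to copy from %src directly:
//
//   memref.copy %src, %dst : memref<4xf32> to memref<4xf32>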
763 
764 /// Fold memref.copy(%x, %x).
765 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
766  using OpRewritePattern<CopyOp>::OpRewritePattern;
767 
768  LogicalResult matchAndRewrite(CopyOp copyOp,
769  PatternRewriter &rewriter) const override {
770  if (copyOp.getSource() != copyOp.getTarget())
771  return failure();
772 
773  rewriter.eraseOp(copyOp);
774  return success();
775  }
776 };
777 
778 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
779  using OpRewritePattern<CopyOp>::OpRewritePattern;
780 
781  static bool isEmptyMemRef(BaseMemRefType type) {
782  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
783  }
784 
785  LogicalResult matchAndRewrite(CopyOp copyOp,
786  PatternRewriter &rewriter) const override {
787  if (isEmptyMemRef(copyOp.getSource().getType()) ||
788  isEmptyMemRef(copyOp.getTarget().getType())) {
789  rewriter.eraseOp(copyOp);
790  return success();
791  }
792 
793  return failure();
794  }
795 };
796 } // namespace
797 
798 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
799  MLIRContext *context) {
800  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
801 }
802 
803 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
804  SmallVectorImpl<OpFoldResult> &results) {
805  /// copy(memrefcast) -> copy
806  bool folded = false;
807  Operation *op = *this;
808  for (OpOperand &operand : op->getOpOperands()) {
809  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
810  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
811  operand.set(castOp.getOperand());
812  folded = true;
813  }
814  }
815  return success(folded);
816 }
817 
818 //===----------------------------------------------------------------------===//
819 // DeallocOp
820 //===----------------------------------------------------------------------===//
821 
822 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
823  SmallVectorImpl<OpFoldResult> &results) {
824  /// dealloc(memrefcast) -> dealloc
825  return foldMemRefCast(*this);
826 }
827 
828 //===----------------------------------------------------------------------===//
829 // DimOp
830 //===----------------------------------------------------------------------===//
831 
832 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
833  setNameFn(getResult(), "dim");
834 }
835 
836 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
837  int64_t index) {
838  auto loc = result.location;
839  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
840  build(builder, result, source, indexValue);
841 }
842 
843 std::optional<int64_t> DimOp::getConstantIndex() {
844  return getConstantIntValue(getIndex());
845 }
846 
847 Speculation::Speculatability DimOp::getSpeculatability() {
848  auto constantIndex = getConstantIndex();
849  if (!constantIndex)
850  return Speculation::NotSpeculatable;
851 
852  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
853  if (!rankedSourceType)
854  return Speculation::NotSpeculatable;
855 
856  if (rankedSourceType.getRank() <= constantIndex)
857  return Speculation::NotSpeculatable;
858 
859  return Speculation::Speculatable;
860 }
861 
862 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
863  SetIntLatticeFn setResultRange) {
864  setResultRange(getResult(),
865  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
866 }
867 
868 /// Return a map with key being elements in `vals` and data being number of
869 /// occurrences of it. Use std::map, since the `vals` here are strides and the
870 /// dynamic stride value is the same as the tombstone value for
871 /// `DenseMap<int64_t>`.
872 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
873  std::map<int64_t, unsigned> numOccurences;
874  for (auto val : vals)
875  numOccurences[val]++;
876  return numOccurences;
877 }
878 
879 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
880 /// to be a subset of `originalType` with some `1` entries erased, return the
881 /// set of indices that specifies which of the entries of `originalShape` are
882 /// dropped to obtain `reducedShape`.
883 /// This accounts for cases where there are multiple unit-dims, but only a
884 /// subset of those are dropped. For MemRefTypes these can be disambiguated
885 /// using the strides. If a dimension is dropped the stride must be dropped too.
886 static FailureOr<llvm::SmallBitVector>
887 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
888  ArrayRef<OpFoldResult> sizes) {
889  llvm::SmallBitVector unusedDims(originalType.getRank());
890  if (originalType.getRank() == reducedType.getRank())
891  return unusedDims;
892 
893  for (const auto &dim : llvm::enumerate(sizes))
894  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
895  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
896  unusedDims.set(dim.index());
897 
898  // Early exit for the case where the number of unused dims matches the number
899  // of ranks reduced.
900  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
901  originalType.getRank())
902  return unusedDims;
903 
904  SmallVector<int64_t> originalStrides, candidateStrides;
905  int64_t originalOffset, candidateOffset;
906  if (failed(
907  originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
908  failed(
909  reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
910  return failure();
911 
912  // For memrefs, a dimension is truly dropped if its corresponding stride is
913  // also dropped. This is particularly important when more than one of the dims
914  // is 1. Track the number of occurrences of the strides in the original type
915  // and the candidate type. For each unused dim that stride should not be
916  // present in the candidate type. Note that there could be multiple dimensions
917  // that have the same size. We don't need to exactly figure out which dim
918  // corresponds to which stride, we just need to verify that the number of
919  // repetitions of a stride in the original + number of unused dims with that
920  // stride == number of repetitions of a stride in the candidate.
921  std::map<int64_t, unsigned> currUnaccountedStrides =
922  getNumOccurences(originalStrides);
923  std::map<int64_t, unsigned> candidateStridesNumOccurences =
924  getNumOccurences(candidateStrides);
925  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
926  if (!unusedDims.test(dim))
927  continue;
928  int64_t originalStride = originalStrides[dim];
929  if (currUnaccountedStrides[originalStride] >
930  candidateStridesNumOccurences[originalStride]) {
931  // This dim can be treated as dropped.
932  currUnaccountedStrides[originalStride]--;
933  continue;
934  }
935  if (currUnaccountedStrides[originalStride] ==
936  candidateStridesNumOccurences[originalStride]) {
937  // The stride for this is not dropped. Keep as is.
938  unusedDims.reset(dim);
939  continue;
940  }
941  if (currUnaccountedStrides[originalStride] <
942  candidateStridesNumOccurences[originalStride]) {
943  // This should never happen. Can't have a stride in the reduced-rank type
944  // that wasn't in the original one.
945  return failure();
946  }
947  }
948 
949  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
950  originalType.getRank())
951  return failure();
952  return unusedDims;
953 }
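// Illustrative sketch (editorial addition), with assumed types: for a
// rank-reducing subview with originalType memref<1x1x8xf32, strided<[8, 8, 1]>>
// and reducedType memref<1x8xf32, strided<[8, 1]>>, both unit dims are
// candidates, but only one occurrence of stride 8 disappears, so exactly one
// of the two unit dims is reported as dropped.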
954 
955 llvm::SmallBitVector SubViewOp::getDroppedDims() {
956  MemRefType sourceType = getSourceType();
957  MemRefType resultType = getType();
958  FailureOr<llvm::SmallBitVector> unusedDims =
959  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
960  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
961  return *unusedDims;
962 }
963 
964 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
965  // All forms of folding require a known index.
966  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
967  if (!index)
968  return {};
969 
970  // Folding for unranked types (UnrankedMemRefType) is not supported.
971  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
972  if (!memrefType)
973  return {};
974 
975  // Out of bound indices produce undefined behavior but are still valid IR.
976  // Don't choke on them.
977  int64_t indexVal = index.getInt();
978  if (indexVal < 0 || indexVal >= memrefType.getRank())
979  return {};
980 
981  // Fold if the shape extent along the given index is known.
982  if (!memrefType.isDynamicDim(index.getInt())) {
983  Builder builder(getContext());
984  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
985  }
986 
987  // The size at the given index is now known to be a dynamic size.
988  unsigned unsignedIndex = index.getValue().getZExtValue();
989 
990  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
991  Operation *definingOp = getSource().getDefiningOp();
992 
993  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
994  return *(alloc.getDynamicSizes().begin() +
995  memrefType.getDynamicDimIndex(unsignedIndex));
996 
997  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
998  return *(alloca.getDynamicSizes().begin() +
999  memrefType.getDynamicDimIndex(unsignedIndex));
1000 
1001  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1002  return *(view.getDynamicSizes().begin() +
1003  memrefType.getDynamicDimIndex(unsignedIndex));
1004 
1005  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1006  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1007  unsigned resultIndex = 0;
1008  unsigned sourceRank = subview.getSourceType().getRank();
1009  unsigned sourceIndex = 0;
1010  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1011  if (unusedDims.test(i))
1012  continue;
1013  if (resultIndex == unsignedIndex) {
1014  sourceIndex = i;
1015  break;
1016  }
1017  resultIndex++;
1018  }
1019  assert(subview.isDynamicSize(sourceIndex) &&
1020  "expected dynamic subview size");
1021  return subview.getDynamicSize(sourceIndex);
1022  }
1023 
1024  if (auto sizeInterface =
1025  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1026  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1027  "Expected dynamic subview size");
1028  return sizeInterface.getDynamicSize(unsignedIndex);
1029  }
1030 
1031  // dim(memrefcast) -> dim
1032  if (succeeded(foldMemRefCast(*this)))
1033  return getResult();
1034 
1035  return {};
1036 }
1037 
1038 namespace {
1039 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1040 /// operand.
1041 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1042  using OpRewritePattern<DimOp>::OpRewritePattern;
1043 
1044  LogicalResult matchAndRewrite(DimOp dim,
1045  PatternRewriter &rewriter) const override {
1046  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1047 
1048  if (!reshape)
1049  return rewriter.notifyMatchFailure(
1050  dim, "Dim op is not defined by a reshape op.");
1051 
1052  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1053  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1054  // cheaply check that either of the following conditions hold:
1055  // 1. dim.getIndex() is defined in the same block as reshape but before
1056  // reshape.
1057  // 2. dim.getIndex() is defined in a parent block of
1058  // reshape.
1059 
1060  // Check condition 1
1061  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1062  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1063  if (reshape->isBeforeInBlock(definingOp)) {
1064  return rewriter.notifyMatchFailure(
1065  dim,
1066  "dim.getIndex is not defined before reshape in the same block.");
1067  }
1068  } // else dim.getIndex is a block argument to reshape->getBlock and
1069  // dominates reshape
1070  } // Check condition 2
1071  else if (dim->getBlock() != reshape->getBlock() &&
1072  !dim.getIndex().getParentRegion()->isProperAncestor(
1073  reshape->getParentRegion())) {
1074  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1075  // already know dim.getIndex() dominates reshape without calling
1076  // `isProperAncestor`
1077  return rewriter.notifyMatchFailure(
1078  dim, "dim.getIndex does not dominate reshape.");
1079  }
1080 
1081  // Place the load directly after the reshape to ensure that the shape memref
1082  // was not mutated.
1083  rewriter.setInsertionPointAfter(reshape);
1084  Location loc = dim.getLoc();
1085  Value load =
1086  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1087  if (load.getType() != dim.getType())
1088  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1089  rewriter.replaceOp(dim, load);
1090  return success();
1091  }
1092 };
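// Illustrative sketch (editorial addition), with assumed SSA names and an
// index-typed shape operand:
//
//   %r = memref.reshape %src(%shape) : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %c1 : memref<?x?xf32>
//
// canonicalizes to a load of the requested extent from the shape memref:
//
//   %d = memref.load %shape[%c1] : memref<2xindex>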
1093 
1094 } // namespace
1095 
1096 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1097  MLIRContext *context) {
1098  results.add<DimOfMemRefReshape>(context);
1099 }
1100 
1101 // ---------------------------------------------------------------------------
1102 // DmaStartOp
1103 // ---------------------------------------------------------------------------
1104 
1105 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1106  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1107  ValueRange destIndices, Value numElements,
1108  Value tagMemRef, ValueRange tagIndices, Value stride,
1109  Value elementsPerStride) {
1110  result.addOperands(srcMemRef);
1111  result.addOperands(srcIndices);
1112  result.addOperands(destMemRef);
1113  result.addOperands(destIndices);
1114  result.addOperands({numElements, tagMemRef});
1115  result.addOperands(tagIndices);
1116  if (stride)
1117  result.addOperands({stride, elementsPerStride});
1118 }
1119 
1120 void DmaStartOp::print(OpAsmPrinter &p) {
1121  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1122  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1123  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1124  if (isStrided())
1125  p << ", " << getStride() << ", " << getNumElementsPerStride();
1126 
1127  p.printOptionalAttrDict((*this)->getAttrs());
1128  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1129  << ", " << getTagMemRef().getType();
1130 }
1131 
1132 // Parse DmaStartOp.
1133 // Ex:
1134 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1135 // %tag[%index], %stride, %num_elt_per_stride :
1136 // : memref<3076 x f32, 0>,
1137 // memref<1024 x f32, 2>,
1138 // memref<1 x i32>
1139 //
1140 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1141  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1142  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1143  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1144  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1145  OpAsmParser::UnresolvedOperand numElementsInfo;
1146  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1147  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1148  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1149 
1150  SmallVector<Type, 3> types;
1151  auto indexType = parser.getBuilder().getIndexType();
1152 
1153  // Parse and resolve the following list of operands:
1154  // *) source memref followed by its indices (in square brackets).
1155  // *) destination memref followed by its indices (in square brackets).
1156  // *) dma size in KiB.
1157  if (parser.parseOperand(srcMemRefInfo) ||
1158  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1159  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1160  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1161  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1162  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1163  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1164  return failure();
1165 
1166  // Parse optional stride and elements per stride.
1167  if (parser.parseTrailingOperandList(strideInfo))
1168  return failure();
1169 
1170  bool isStrided = strideInfo.size() == 2;
1171  if (!strideInfo.empty() && !isStrided) {
1172  return parser.emitError(parser.getNameLoc(),
1173  "expected two stride related operands");
1174  }
1175 
1176  if (parser.parseColonTypeList(types))
1177  return failure();
1178  if (types.size() != 3)
1179  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1180 
1181  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1182  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1183  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1184  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1185  // size should be an index.
1186  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1187  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1188  // tag indices should be index.
1189  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1190  return failure();
1191 
1192  if (isStrided) {
1193  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1194  return failure();
1195  }
1196 
1197  return success();
1198 }
1199 
1200 LogicalResult DmaStartOp::verify() {
1201  unsigned numOperands = getNumOperands();
1202 
1203  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1204  // the number of elements.
1205  if (numOperands < 4)
1206  return emitOpError("expected at least 4 operands");
1207 
1208  // Check types of operands. The order of these calls is important: the later
1209  // calls rely on some type properties to compute the operand position.
1210  // 1. Source memref.
1211  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1212  return emitOpError("expected source to be of memref type");
1213  if (numOperands < getSrcMemRefRank() + 4)
1214  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1215  << " operands";
1216  if (!getSrcIndices().empty() &&
1217  !llvm::all_of(getSrcIndices().getTypes(),
1218  [](Type t) { return t.isIndex(); }))
1219  return emitOpError("expected source indices to be of index type");
1220 
1221  // 2. Destination memref.
1222  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1223  return emitOpError("expected destination to be of memref type");
1224  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1225  if (numOperands < numExpectedOperands)
1226  return emitOpError() << "expected at least " << numExpectedOperands
1227  << " operands";
1228  if (!getDstIndices().empty() &&
1229  !llvm::all_of(getDstIndices().getTypes(),
1230  [](Type t) { return t.isIndex(); }))
1231  return emitOpError("expected destination indices to be of index type");
1232 
1233  // 3. Number of elements.
1234  if (!getNumElements().getType().isIndex())
1235  return emitOpError("expected num elements to be of index type");
1236 
1237  // 4. Tag memref.
1238  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1239  return emitOpError("expected tag to be of memref type");
1240  numExpectedOperands += getTagMemRefRank();
1241  if (numOperands < numExpectedOperands)
1242  return emitOpError() << "expected at least " << numExpectedOperands
1243  << " operands";
1244  if (!getTagIndices().empty() &&
1245  !llvm::all_of(getTagIndices().getTypes(),
1246  [](Type t) { return t.isIndex(); }))
1247  return emitOpError("expected tag indices to be of index type");
1248 
1249  // Optional stride-related operands must be either both present or both
1250  // absent.
1251  if (numOperands != numExpectedOperands &&
1252  numOperands != numExpectedOperands + 2)
1253  return emitOpError("incorrect number of operands");
1254 
1255  // 5. Strides.
1256  if (isStrided()) {
1257  if (!getStride().getType().isIndex() ||
1258  !getNumElementsPerStride().getType().isIndex())
1259  return emitOpError(
1260  "expected stride and num elements per stride to be of type index");
1261  }
1262 
1263  return success();
1264 }
1265 
1266 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1267  SmallVectorImpl<OpFoldResult> &results) {
1268  /// dma_start(memrefcast) -> dma_start
1269  return foldMemRefCast(*this);
1270 }
1271 
1272 // ---------------------------------------------------------------------------
1273 // DmaWaitOp
1274 // ---------------------------------------------------------------------------
1275 
1276 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1277  SmallVectorImpl<OpFoldResult> &results) {
1278  /// dma_wait(memrefcast) -> dma_wait
1279  return foldMemRefCast(*this);
1280 }
1281 
1282 LogicalResult DmaWaitOp::verify() {
1283  // Check that the number of tag indices matches the tagMemRef rank.
1284  unsigned numTagIndices = getTagIndices().size();
1285  unsigned tagMemRefRank = getTagMemRefRank();
1286  if (numTagIndices != tagMemRefRank)
1287  return emitOpError() << "expected tagIndices to have the same number of "
1288  "elements as the tagMemRef rank, expected "
1289  << tagMemRefRank << ", but got " << numTagIndices;
1290  return success();
1291 }
1292 
1293 //===----------------------------------------------------------------------===//
1294 // ExtractAlignedPointerAsIndexOp
1295 //===----------------------------------------------------------------------===//
1296 
1297 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1298  function_ref<void(Value, StringRef)> setNameFn) {
1299  setNameFn(getResult(), "intptr");
1300 }
1301 
1302 //===----------------------------------------------------------------------===//
1303 // ExtractStridedMetadataOp
1304 //===----------------------------------------------------------------------===//
1305 
1306 /// The number and type of the results are inferred from the
1307 /// shape of the source.
1308 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1309  MLIRContext *context, std::optional<Location> location,
1310  ExtractStridedMetadataOp::Adaptor adaptor,
1311  SmallVectorImpl<Type> &inferredReturnTypes) {
1312  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1313  if (!sourceType)
1314  return failure();
1315 
1316  unsigned sourceRank = sourceType.getRank();
1317  IndexType indexType = IndexType::get(context);
1318  auto memrefType =
1319  MemRefType::get({}, sourceType.getElementType(),
1320  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1321  // Base.
1322  inferredReturnTypes.push_back(memrefType);
1323  // Offset.
1324  inferredReturnTypes.push_back(indexType);
1325  // Sizes and strides.
1326  for (unsigned i = 0; i < sourceRank * 2; ++i)
1327  inferredReturnTypes.push_back(indexType);
1328  return success();
1329 }
1330 
1331 void ExtractStridedMetadataOp::getAsmResultNames(
1332  function_ref<void(Value, StringRef)> setNameFn) {
1333  setNameFn(getBaseBuffer(), "base_buffer");
1334  setNameFn(getOffset(), "offset");
1335  // For multi-result to work properly with pretty names and packed syntax `x:3`
1336  // we can only give a pretty name to the first value in the pack.
1337  if (!getSizes().empty()) {
1338  setNameFn(getSizes().front(), "sizes");
1339  setNameFn(getStrides().front(), "strides");
1340  }
1341 }
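// Illustrative sketch (editorial addition, assembly shown schematically): for
// a 2-D source the op decomposes a memref into its base buffer, offset,
// sizes, and strides, which is why only the first size/stride gets a name:
//
//   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
//       : memref<8x?xf32> -> memref<f32>, index, index, index, index, index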
1342 
1343 /// Helper function to perform the replacement of all constant uses of `values`
1344 /// by a materialized constant extracted from `maybeConstants`.
1345 /// `values` and `maybeConstants` are expected to have the same size.
1346 template <typename Container>
1347 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1348  Container values,
1349  ArrayRef<OpFoldResult> maybeConstants) {
1350  assert(values.size() == maybeConstants.size() &&
1351  " expected values and maybeConstants of the same size");
1352  bool atLeastOneReplacement = false;
1353  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1354  // Don't materialize a constant if there are no uses: this would induce
1355  // infinite loops in the driver.
1356  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1357  continue;
1358  assert(isa<Attribute>(maybeConstant) &&
1359  "The constified value should be either unchanged (i.e., == result) "
1360  "or a constant");
1361  Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1362  loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1363  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1364  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1365  // yet.
1366  op->replaceUsesOfWith(result, constantVal);
1367  atLeastOneReplacement = true;
1368  }
1369  }
1370  return atLeastOneReplacement;
1371 }
1372 
1373 LogicalResult
1374 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1375  SmallVectorImpl<OpFoldResult> &results) {
1376  OpBuilder builder(*this);
1377 
1378  bool atLeastOneReplacement = replaceConstantUsesOf(
1379  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1380  getConstifiedMixedOffset());
1381  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1382  getConstifiedMixedSizes());
1383  atLeastOneReplacement |= replaceConstantUsesOf(
1384  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1385 
1386  return success(atLeastOneReplacement);
1387 }
1388 
1389 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1390  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1391  constifyIndexValues(values, getSource().getType().getShape());
1392  return values;
1393 }
1394 
1395 SmallVector<OpFoldResult>
1396 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1397  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1398  SmallVector<int64_t> staticValues;
1399  int64_t unused;
1400  LogicalResult status =
1401  getSource().getType().getStridesAndOffset(staticValues, unused);
1402  (void)status;
1403  assert(succeeded(status) && "could not get strides from type");
1404  constifyIndexValues(values, staticValues);
1405  return values;
1406 }
1407 
1408 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1409  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1410  SmallVector<OpFoldResult> values(1, offsetOfr);
1411  SmallVector<int64_t> staticValues, unused;
1412  int64_t offset;
1413  LogicalResult status =
1414  getSource().getType().getStridesAndOffset(unused, offset);
1415  (void)status;
1416  assert(succeeded(status) && "could not get offset from type");
1417  staticValues.push_back(offset);
1418  constifyIndexValues(values, staticValues);
1419  return values[0];
1420 }
1421 
1422 //===----------------------------------------------------------------------===//
1423 // GenericAtomicRMWOp
1424 //===----------------------------------------------------------------------===//
1425 
1426 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1427  Value memref, ValueRange ivs) {
1428  OpBuilder::InsertionGuard g(builder);
1429  result.addOperands(memref);
1430  result.addOperands(ivs);
1431 
1432  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1433  Type elementType = memrefType.getElementType();
1434  result.addTypes(elementType);
1435 
1436  Region *bodyRegion = result.addRegion();
1437  builder.createBlock(bodyRegion);
1438  bodyRegion->addArgument(elementType, memref.getLoc());
1439  }
1440 }
1441 
1442 LogicalResult GenericAtomicRMWOp::verify() {
1443  auto &body = getRegion();
1444  if (body.getNumArguments() != 1)
1445  return emitOpError("expected single number of entry block arguments");
1446 
1447  if (getResult().getType() != body.getArgument(0).getType())
1448  return emitOpError("expected block argument of the same type result type");
1449 
1450  bool hasSideEffects =
1451  body.walk([&](Operation *nestedOp) {
1452  if (isMemoryEffectFree(nestedOp))
1453  return WalkResult::advance();
1454  nestedOp->emitError(
1455  "body of 'memref.generic_atomic_rmw' should contain "
1456  "only operations with no side effects");
1457  return WalkResult::interrupt();
1458  })
1459  .wasInterrupted();
1460  return hasSideEffects ? failure() : success();
1461 }
1462 
1463 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1464  OperationState &result) {
1465  OpAsmParser::UnresolvedOperand memref;
1466  Type memrefType;
1467  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1468 
1469  Type indexType = parser.getBuilder().getIndexType();
1470  if (parser.parseOperand(memref) ||
1471  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1472  parser.parseColonType(memrefType) ||
1473  parser.resolveOperand(memref, memrefType, result.operands) ||
1474  parser.resolveOperands(ivs, indexType, result.operands))
1475  return failure();
1476 
1477  Region *body = result.addRegion();
1478  if (parser.parseRegion(*body, {}) ||
1479  parser.parseOptionalAttrDict(result.attributes))
1480  return failure();
1481  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1482  return success();
1483 }
1484 
1485 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1486  p << ' ' << getMemref() << "[" << getIndices()
1487  << "] : " << getMemref().getType() << ' ';
1488  p.printRegion(getRegion());
1489  p.printOptionalAttrDict((*this)->getAttrs());
1490 }
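// Illustrative sketch (editorial addition), with assumed SSA names: the form
// produced and consumed by the printer and parser above:
//
//   %x = memref.generic_atomic_rmw %buf[%i] : memref<10xf32> {
//   ^bb0(%current : f32):
//     %inc = arith.addf %current, %cst : f32
//     memref.atomic_yield %inc : f32
//   }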
1491 
1492 //===----------------------------------------------------------------------===//
1493 // AtomicYieldOp
1494 //===----------------------------------------------------------------------===//
1495 
1496 LogicalResult AtomicYieldOp::verify() {
1497  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1498  Type resultType = getResult().getType();
1499  if (parentType != resultType)
1500  return emitOpError() << "types mismatch between yield op: " << resultType
1501  << " and its parent: " << parentType;
1502  return success();
1503 }
1504 
1505 //===----------------------------------------------------------------------===//
1506 // GlobalOp
1507 //===----------------------------------------------------------------------===//
1508 
1509 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1510  TypeAttr type,
1511  Attribute initialValue) {
1512  p << type;
1513  if (!op.isExternal()) {
1514  p << " = ";
1515  if (op.isUninitialized())
1516  p << "uninitialized";
1517  else
1518  p.printAttributeWithoutType(initialValue);
1519  }
1520 }
1521 
1522 static ParseResult
1523 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1524  Attribute &initialValue) {
1525  Type type;
1526  if (parser.parseType(type))
1527  return failure();
1528 
1529  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1530  if (!memrefType || !memrefType.hasStaticShape())
1531  return parser.emitError(parser.getNameLoc())
1532  << "type should be static shaped memref, but got " << type;
1533  typeAttr = TypeAttr::get(type);
1534 
1535  if (parser.parseOptionalEqual())
1536  return success();
1537 
1538  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1539  initialValue = UnitAttr::get(parser.getContext());
1540  return success();
1541  }
1542 
1543  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1544  if (parser.parseAttribute(initialValue, tensorType))
1545  return failure();
1546  if (!llvm::isa<ElementsAttr>(initialValue))
1547  return parser.emitError(parser.getNameLoc())
1548  << "initial value should be a unit or elements attribute";
1549  return success();
1550 }
1551 
1552 LogicalResult GlobalOp::verify() {
1553  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1554  if (!memrefType || !memrefType.hasStaticShape())
1555  return emitOpError("type should be static shaped memref, but got ")
1556  << getType();
1557 
1558  // Verify that the initial value, if present, is either a unit attribute or
1559  // an elements attribute.
1560  if (getInitialValue().has_value()) {
1561  Attribute initValue = getInitialValue().value();
1562  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1563  return emitOpError("initial value should be a unit or elements "
1564  "attribute, but got ")
1565  << initValue;
1566 
1567  // Check that the type of the initial value is compatible with the type of
1568  // the global variable.
1569  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1570  // Check the element types match.
1571  auto initElementType =
1572  cast<TensorType>(elementsAttr.getType()).getElementType();
1573  auto memrefElementType = memrefType.getElementType();
1574 
1575  if (initElementType != memrefElementType)
1576  return emitOpError("initial value element expected to be of type ")
1577  << memrefElementType << ", but was of type " << initElementType;
1578 
1579  // Check that the shapes match. Since memref globals can only produce
1580  // statically shaped memrefs and elements literals must have a static
1581  // shape, we can assume both types are shaped.
1582  auto initShape = elementsAttr.getShapedType().getShape();
1583  auto memrefShape = memrefType.getShape();
1584  if (initShape != memrefShape)
1585  return emitOpError("initial value shape expected to be ")
1586  << memrefShape << " but was " << initShape;
1587  }
1588  }
1589 
1590  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1591  uint64_t alignment = *alignAttr;
1592 
1593  if (!llvm::isPowerOf2_64(alignment))
1594  return emitError() << "alignment attribute value " << alignment
1595  << " is not a power of 2";
1596  }
1597 
1598  // TODO: verify visibility for declarations.
1599  return success();
1600 }
1601 
1602 ElementsAttr GlobalOp::getConstantInitValue() {
1603  auto initVal = getInitialValue();
1604  if (getConstant() && initVal.has_value())
1605  return llvm::cast<ElementsAttr>(initVal.value());
1606  return {};
1607 }
1608 
1609 //===----------------------------------------------------------------------===//
1610 // GetGlobalOp
1611 //===----------------------------------------------------------------------===//
1612 
1613 LogicalResult
1614 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1615  // Verify that the result type is the same as the type of the referenced
1616  // memref.global op.
1617  auto global =
1618  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1619  if (!global)
1620  return emitOpError("'")
1621  << getName() << "' does not reference a valid global memref";
1622 
1623  Type resultType = getResult().getType();
1624  if (global.getType() != resultType)
1625  return emitOpError("result type ")
1626  << resultType << " does not match type " << global.getType()
1627  << " of the global memref @" << getName();
1628  return success();
1629 }
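// As an illustrative sketch (symbol and value are placeholders), a global and
// a get_global that these verifiers accept look like:
//
//   memref.global "private" constant @c : memref<2xi32> = dense<[1, 2]>
//   ...
//   %0 = memref.get_global @c : memref<2xi32>
//
// The get_global result type must match the type of the referenced global
// exactly.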
1630 
1631 //===----------------------------------------------------------------------===//
1632 // LoadOp
1633 //===----------------------------------------------------------------------===//
1634 
1635 LogicalResult LoadOp::verify() {
1636  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1637  return emitOpError("incorrect number of indices for load, expected ")
1638  << getMemRefType().getRank() << " but got " << getIndices().size();
1639  }
1640  return success();
1641 }
1642 
1643 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1644  /// load(memrefcast) -> load
1645  if (succeeded(foldMemRefCast(*this)))
1646  return getResult();
1647  return OpFoldResult();
1648 }
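// As an illustrative sketch of the fold above (placeholder names), a load
// through a memref.cast is rewritten to read from the original memref:
//
//   %c = memref.cast %A : memref<4xf32> to memref<?xf32>
//   %v = memref.load %c[%i] : memref<?xf32>
//
// folds to:
//
//   %v = memref.load %A[%i] : memref<4xf32>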
1649 
1650 //===----------------------------------------------------------------------===//
1651 // MemorySpaceCastOp
1652 //===----------------------------------------------------------------------===//
1653 
1654 void MemorySpaceCastOp::getAsmResultNames(
1655  function_ref<void(Value, StringRef)> setNameFn) {
1656  setNameFn(getResult(), "memspacecast");
1657 }
1658 
1659 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1660  if (inputs.size() != 1 || outputs.size() != 1)
1661  return false;
1662  Type a = inputs.front(), b = outputs.front();
1663  auto aT = llvm::dyn_cast<MemRefType>(a);
1664  auto bT = llvm::dyn_cast<MemRefType>(b);
1665 
1666  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1667  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1668 
1669  if (aT && bT) {
1670  if (aT.getElementType() != bT.getElementType())
1671  return false;
1672  if (aT.getLayout() != bT.getLayout())
1673  return false;
1674  if (aT.getShape() != bT.getShape())
1675  return false;
1676  return true;
1677  }
1678  if (uaT && ubT) {
1679  return uaT.getElementType() == ubT.getElementType();
1680  }
1681  return false;
1682 }
1683 
1684 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1685  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1686  // t2)
1687  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1688  getSourceMutable().assign(parentCast.getSource());
1689  return getResult();
1690  }
1691  return Value{};
1692 }
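// As an illustrative sketch (memory spaces are placeholders), the fold above
// collapses chained casts so only one cast remains:
//
//   %1 = memref.memory_space_cast %0 : memref<4xf32, 1> to memref<4xf32, 2>
//   %2 = memref.memory_space_cast %1 : memref<4xf32, 2> to memref<4xf32>
//
// folds to:
//
//   %2 = memref.memory_space_cast %0 : memref<4xf32, 1> to memref<4xf32>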
1693 
1694 //===----------------------------------------------------------------------===//
1695 // PrefetchOp
1696 //===----------------------------------------------------------------------===//
1697 
1698 void PrefetchOp::print(OpAsmPrinter &p) {
1699  p << " " << getMemref() << '[';
1700  p.printOperands(getIndices());
1701  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1702  p << ", locality<" << getLocalityHint();
1703  p << ">, " << (getIsDataCache() ? "data" : "instr");
1704  p.printOptionalAttrDict(
1705  (*this)->getAttrs(),
1706  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1707  p << " : " << getMemRefType();
1708 }
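// As an illustrative sketch (names are placeholders), the printer above emits
// assembly of the form:
//
//   memref.prefetch %A[%i, %j], read, locality<3>, data : memref<400x400xf32>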
1709 
1710 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1711  OpAsmParser::UnresolvedOperand memrefInfo;
1712  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1713  IntegerAttr localityHint;
1714  MemRefType type;
1715  StringRef readOrWrite, cacheType;
1716 
1717  auto indexTy = parser.getBuilder().getIndexType();
1718  auto i32Type = parser.getBuilder().getIntegerType(32);
1719  if (parser.parseOperand(memrefInfo) ||
1720  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1721  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1722  parser.parseComma() || parser.parseKeyword("locality") ||
1723  parser.parseLess() ||
1724  parser.parseAttribute(localityHint, i32Type, "localityHint",
1725  result.attributes) ||
1726  parser.parseGreater() || parser.parseComma() ||
1727  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1728  parser.resolveOperand(memrefInfo, type, result.operands) ||
1729  parser.resolveOperands(indexInfo, indexTy, result.operands))
1730  return failure();
1731 
1732  if (readOrWrite != "read" && readOrWrite != "write")
1733  return parser.emitError(parser.getNameLoc(),
1734  "rw specifier has to be 'read' or 'write'");
1735  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1736  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1737 
1738  if (cacheType != "data" && cacheType != "instr")
1739  return parser.emitError(parser.getNameLoc(),
1740  "cache type has to be 'data' or 'instr'");
1741 
1742  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1743  parser.getBuilder().getBoolAttr(cacheType == "data"));
1744 
1745  return success();
1746 }
1747 
1748 LogicalResult PrefetchOp::verify() {
1749  if (getNumOperands() != 1 + getMemRefType().getRank())
1750  return emitOpError("too few indices");
1751 
1752  return success();
1753 }
1754 
1755 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1756  SmallVectorImpl<OpFoldResult> &results) {
1757  // prefetch(memrefcast) -> prefetch
1758  return foldMemRefCast(*this);
1759 }
1760 
1761 //===----------------------------------------------------------------------===//
1762 // RankOp
1763 //===----------------------------------------------------------------------===//
1764 
1765 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1766  // Constant fold rank when the rank of the operand is known.
1767  auto type = getOperand().getType();
1768  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1769  if (shapedType && shapedType.hasRank())
1770  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1771  return IntegerAttr();
1772 }
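// As an illustrative sketch (names are placeholders), the fold above turns
//
//   %r = memref.rank %A : memref<4x?xf32>
//
// into the index constant 2, since the rank is known from the type.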
1773 
1774 //===----------------------------------------------------------------------===//
1775 // ReinterpretCastOp
1776 //===----------------------------------------------------------------------===//
1777 
1778 void ReinterpretCastOp::getAsmResultNames(
1779  function_ref<void(Value, StringRef)> setNameFn) {
1780  setNameFn(getResult(), "reinterpret_cast");
1781 }
1782 
1783 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1784 /// `staticSizes` and `staticStrides` are automatically filled with
1785 /// source-memref-rank sentinel values that encode dynamic entries.
1786 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1787  MemRefType resultType, Value source,
1788  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1789  ArrayRef<OpFoldResult> strides,
1790  ArrayRef<NamedAttribute> attrs) {
1791  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1792  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1793  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1794  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1795  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1796  result.addAttributes(attrs);
1797  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1798  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1799  b.getDenseI64ArrayAttr(staticSizes),
1800  b.getDenseI64ArrayAttr(staticStrides));
1801 }
1802 
1803 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1804  Value source, OpFoldResult offset,
1805  ArrayRef<OpFoldResult> sizes,
1806  ArrayRef<OpFoldResult> strides,
1807  ArrayRef<NamedAttribute> attrs) {
1808  auto sourceType = cast<BaseMemRefType>(source.getType());
1809  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1810  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1811  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1812  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1813  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1814  auto stridedLayout = StridedLayoutAttr::get(
1815  b.getContext(), staticOffsets.front(), staticStrides);
1816  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1817  stridedLayout, sourceType.getMemorySpace());
1818  build(b, result, resultType, source, offset, sizes, strides, attrs);
1819 }
1820 
1821 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1822  MemRefType resultType, Value source,
1823  int64_t offset, ArrayRef<int64_t> sizes,
1824  ArrayRef<int64_t> strides,
1825  ArrayRef<NamedAttribute> attrs) {
1826  SmallVector<OpFoldResult> sizeValues =
1827  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1828  return b.getI64IntegerAttr(v);
1829  }));
1830  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1831  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1832  return b.getI64IntegerAttr(v);
1833  }));
1834  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1835  strideValues, attrs);
1836 }
1837 
1838 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1839  MemRefType resultType, Value source, Value offset,
1840  ValueRange sizes, ValueRange strides,
1841  ArrayRef<NamedAttribute> attrs) {
1842  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1843  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1844  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1845  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1846  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1847 }
1848 
1849 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1850 // completed automatically, like we have for subview and extract_slice.
1851 LogicalResult ReinterpretCastOp::verify() {
1852  // The source and result memrefs should be in the same memory space.
1853  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1854  auto resultType = llvm::cast<MemRefType>(getType());
1855  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1856  return emitError("different memory spaces specified for source type ")
1857  << srcType << " and result memref type " << resultType;
1858  if (srcType.getElementType() != resultType.getElementType())
1859  return emitError("different element types specified for source type ")
1860  << srcType << " and result memref type " << resultType;
1861 
1862  // Match sizes in result memref type and in static_sizes attribute.
1863  for (auto [idx, resultSize, expectedSize] :
1864  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1865  if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
1866  return emitError("expected result type with size = ")
1867  << (ShapedType::isDynamic(expectedSize)
1868  ? std::string("dynamic")
1869  : std::to_string(expectedSize))
1870  << " instead of " << resultSize << " in dim = " << idx;
1871  }
1872 
1873  // Match the offset and strides in the static_offsets and static_strides
1874  // attributes. If the result memref type has no affine map specified, this
1875  // assumes an identity layout.
1876  int64_t resultOffset;
1877  SmallVector<int64_t, 4> resultStrides;
1878  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1879  return emitError("expected result type to have strided layout but found ")
1880  << resultType;
1881 
1882  // Match offset in result memref type and in static_offsets attribute.
1883  int64_t expectedOffset = getStaticOffsets().front();
1884  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
1885  return emitError("expected result type with offset = ")
1886  << (ShapedType::isDynamic(expectedOffset)
1887  ? std::string("dynamic")
1888  : std::to_string(expectedOffset))
1889  << " instead of " << resultOffset;
1890 
1891  // Match strides in result memref type and in static_strides attribute.
1892  for (auto [idx, resultStride, expectedStride] :
1893  llvm::enumerate(resultStrides, getStaticStrides())) {
1894  if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
1895  return emitError("expected result type with stride = ")
1896  << (ShapedType::isDynamic(expectedStride)
1897  ? std::string("dynamic")
1898  : std::to_string(expectedStride))
1899  << " instead of " << resultStride << " in dim = " << idx;
1900  }
1901 
1902  return success();
1903 }
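// As an illustrative sketch (names are placeholders), an op accepted by this
// verifier keeps the static offset/sizes/strides consistent with its result
// layout:
//
//   memref.reinterpret_cast %ranked to
//     offset: [0],
//     sizes: [%size0, 10],
//     strides: [1, %stride1]
//   : memref<?x?xf32> to memref<?x10xf32, strided<[1, ?], offset: 0>>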
1904 
1905 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1906  Value src = getSource();
1907  auto getPrevSrc = [&]() -> Value {
1908  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1909  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1910  return prev.getSource();
1911 
1912  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1913  if (auto prev = src.getDefiningOp<CastOp>())
1914  return prev.getSource();
1915 
1916  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1917  // are 0.
1918  if (auto prev = src.getDefiningOp<SubViewOp>())
1919  if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
1920  return prev.getSource();
1921 
1922  return nullptr;
1923  };
1924 
1925  if (auto prevSrc = getPrevSrc()) {
1926  getSourceMutable().assign(prevSrc);
1927  return getResult();
1928  }
1929 
1930  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1931  if (ShapedType::isStaticShape(getType().getShape()) &&
1932  src.getType() == getType() && getStaticOffsets().front() == 0) {
1933  return src;
1934  }
1935 
1936  return nullptr;
1937 }
1938 
1939 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1940  SmallVector<OpFoldResult> values = getMixedSizes();
1941  constifyIndexValues(values, getType().getShape());
1942  return values;
1943 }
1944 
1945 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1946  SmallVector<OpFoldResult> values = getMixedStrides();
1947  SmallVector<int64_t> staticValues;
1948  int64_t unused;
1949  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
1950  (void)status;
1951  assert(succeeded(status) && "could not get strides from type");
1952  constifyIndexValues(values, staticValues);
1953  return values;
1954 }
1955 
1956 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1957  SmallVector<OpFoldResult> values = getMixedOffsets();
1958  assert(values.size() == 1 &&
1959  "reinterpret_cast must have one and only one offset");
1960  SmallVector<int64_t> staticValues, unused;
1961  int64_t offset;
1962  LogicalResult status = getType().getStridesAndOffset(unused, offset);
1963  (void)status;
1964  assert(succeeded(status) && "could not get offset from type");
1965  staticValues.push_back(offset);
1966  constifyIndexValues(values, staticValues);
1967  return values[0];
1968 }
1969 
1970 namespace {
1971 /// Replace the sequence:
1972 /// ```
1973 /// base, offset, sizes, strides = extract_strided_metadata src
1974 /// dst = reinterpret_cast base to offset, sizes, strides
1975 /// ```
1976 /// With
1977 ///
1978 /// ```
1979 /// dst = memref.cast src
1980 /// ```
1981 ///
1982 /// Note: The cast operation is only inserted when the type of dst and src
1983 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
1984 ///
1985 /// This pattern also matches when the offset, sizes, and strides don't come
1986 /// directly from the `extract_strided_metadata`'s results but it can be
1987 /// statically proven that they would hold the same values.
1988 ///
1989 /// For instance, the following sequence would be replaced:
1990 /// ```
1991 /// base, offset, sizes, strides =
1992 /// extract_strided_metadata memref : memref<3x4xty>
1993 /// dst = reinterpret_cast base to 0, [3, 4], strides
1994 /// ```
1995 /// Because we know (thanks to the type of the input memref) that variable
1996 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
1997 ///
1998 /// Similarly, the following sequence would be replaced:
1999 /// ```
2000 /// c0 = arith.constant 0
2001 /// c4 = arith.constant 4
2002 /// base, offset, sizes, strides =
2003 /// extract_strided_metadata memref : memref<3x4xty>
2004 /// dst = reinterpret_cast base to c0, [3, c4], strides
2005 /// ```
2006 /// Because we know that `offset` and `c0` will hold 0
2007 /// and `c4` will hold 4.
2008 ///
2009 /// If the pattern above does not match, the input of the
2010 /// extract_strided_metadata is always folded into the input of the
2011 /// reinterpret_cast operator. This allows for dead code elimination to get rid
2012 /// of the extract_strided_metadata in some cases.
2013 struct ReinterpretCastOpExtractStridedMetadataFolder
2014  : public OpRewritePattern<ReinterpretCastOp> {
2015 public:
2016  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2017 
2018  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2019  PatternRewriter &rewriter) const override {
2020  auto extractStridedMetadata =
2021  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2022  if (!extractStridedMetadata)
2023  return failure();
2024 
2025  // Check if the reinterpret cast reconstructs a memref with the exact same
2026  // properties as the extract strided metadata.
2027  auto isReinterpretCastNoop = [&]() -> bool {
2028  // First, check that the strides are the same.
2029  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
2030  op.getConstifiedMixedStrides()))
2031  return false;
2032 
2033  // Second, check the sizes.
2034  if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
2035  op.getConstifiedMixedSizes()))
2036  return false;
2037 
2038  // Finally, check the offset.
2039  assert(op.getMixedOffsets().size() == 1 &&
2040  "reinterpret_cast with more than one offset should have been "
2041  "rejected by the verifier");
2042  return extractStridedMetadata.getConstifiedMixedOffset() ==
2043  op.getConstifiedMixedOffset();
2044  };
2045 
2046  if (!isReinterpretCastNoop()) {
2047  // If the extract_strided_metadata / reinterpret_cast pair can't be
2048  // completely folded, we can still fold the input of the
2049  // extract_strided_metadata into the input of the reinterpret_cast.
2050  // For some cases (e.g., static dimensions) the
2051  // extract_strided_metadata is then eliminated by dead code elimination.
2052  //
2053  // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2054  //
2055  // We can always fold the input of a extract_strided_metadata operator
2056  // to the input of a reinterpret_cast operator, because they point to
2057  // the same memory. Note that the reinterpret_cast does not use the
2058  // layout of its input memref, only its base memory pointer which is
2059  // the same as the base pointer returned by the extract_strided_metadata
2060  // operator and the base pointer of the extract_strided_metadata memref
2061  // input.
2062  rewriter.modifyOpInPlace(op, [&]() {
2063  op.getSourceMutable().assign(extractStridedMetadata.getSource());
2064  });
2065  return success();
2066  }
2067 
2068  // At this point, we know that the back and forth between extract strided
2069  // metadata and reinterpret cast is a noop. However, the final type of the
2070  // reinterpret cast may not be exactly the same as the original memref.
2071  // E.g., it could be changing a dimension from static to dynamic. Check that
2072  // here and add a cast if necessary.
2073  Type srcTy = extractStridedMetadata.getSource().getType();
2074  if (srcTy == op.getResult().getType())
2075  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2076  else
2077  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2078  extractStridedMetadata.getSource());
2079 
2080  return success();
2081  }
2082 };
2083 } // namespace
2084 
2085 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2086  MLIRContext *context) {
2087  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2088 }
2089 
2090 //===----------------------------------------------------------------------===//
2091 // Reassociative reshape ops
2092 //===----------------------------------------------------------------------===//
2093 
2094 void CollapseShapeOp::getAsmResultNames(
2095  function_ref<void(Value, StringRef)> setNameFn) {
2096  setNameFn(getResult(), "collapse_shape");
2097 }
2098 
2099 void ExpandShapeOp::getAsmResultNames(
2100  function_ref<void(Value, StringRef)> setNameFn) {
2101  setNameFn(getResult(), "expand_shape");
2102 }
2103 
2104 LogicalResult ExpandShapeOp::reifyResultShapes(
2105  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2106  reifiedResultShapes = {
2107  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2108  return success();
2109 }
2110 
2111 /// Helper function for verifying the shape of ExpandShapeOp and
2112 /// CollapseShapeOp result and operand. Layout maps are verified separately.
2113 ///
2114 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2115 /// allowed in a reassociation group.
2116 static LogicalResult
2117 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2118  ArrayRef<int64_t> expandedShape,
2119  ArrayRef<ReassociationIndices> reassociation,
2120  bool allowMultipleDynamicDimsPerGroup) {
2121  // There must be one reassociation group per collapsed dimension.
2122  if (collapsedShape.size() != reassociation.size())
2123  return op->emitOpError("invalid number of reassociation groups: found ")
2124  << reassociation.size() << ", expected " << collapsedShape.size();
2125 
2126  // The next expected expanded dimension index (while iterating over
2127  // reassociation indices).
2128  int64_t nextDim = 0;
2129  for (const auto &it : llvm::enumerate(reassociation)) {
2130  ReassociationIndices group = it.value();
2131  int64_t collapsedDim = it.index();
2132 
2133  bool foundDynamic = false;
2134  for (int64_t expandedDim : group) {
2135  if (expandedDim != nextDim++)
2136  return op->emitOpError("reassociation indices must be contiguous");
2137 
2138  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2139  return op->emitOpError("reassociation index ")
2140  << expandedDim << " is out of bounds";
2141 
2142  // Check if there are multiple dynamic dims in a reassociation group.
2143  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2144  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2145  return op->emitOpError(
2146  "at most one dimension in a reassociation group may be dynamic");
2147  foundDynamic = true;
2148  }
2149  }
2150 
2151  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2152  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2153  return op->emitOpError("collapsed dim (")
2154  << collapsedDim
2155  << ") must be dynamic if and only if reassociation group is "
2156  "dynamic";
2157 
2158  // If all dims in the reassociation group are static, the size of the
2159  // collapsed dim can be verified.
2160  if (!foundDynamic) {
2161  int64_t groupSize = 1;
2162  for (int64_t expandedDim : group)
2163  groupSize *= expandedShape[expandedDim];
2164  if (groupSize != collapsedShape[collapsedDim])
2165  return op->emitOpError("collapsed dim size (")
2166  << collapsedShape[collapsedDim]
2167  << ") must equal reassociation group size (" << groupSize << ")";
2168  }
2169  }
2170 
2171  if (collapsedShape.empty()) {
2172  // Rank 0: All expanded dimensions must be 1.
2173  for (int64_t d : expandedShape)
2174  if (d != 1)
2175  return op->emitOpError(
2176  "rank 0 memrefs can only be extended/collapsed with/from ones");
2177  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2178  // Rank >= 1: Number of dimensions among all reassociation groups must match
2179  // the result memref rank.
2180  return op->emitOpError("expanded rank (")
2181  << expandedShape.size()
2182  << ") inconsistent with number of reassociation indices (" << nextDim
2183  << ")";
2184  }
2185 
2186  return success();
2187 }
2188 
2189 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2190  return getSymbolLessAffineMaps(getReassociationExprs());
2191 }
2192 
2193 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2194  return convertReassociationIndicesToExprs(getContext(),
2195  getReassociationIndices());
2196 }
2197 
2198 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2199  return getSymbolLessAffineMaps(getReassociationExprs());
2200 }
2201 
2202 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2203  return convertReassociationIndicesToExprs(getContext(),
2204  getReassociationIndices());
2205 }
2206 
2207 /// Compute the layout map after expanding a given source MemRef type with the
2208 /// specified reassociation indices.
2209 static FailureOr<StridedLayoutAttr>
2210 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2211  ArrayRef<ReassociationIndices> reassociation) {
2212  int64_t srcOffset;
2213  SmallVector<int64_t> srcStrides;
2214  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2215  return failure();
2216  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2217 
2218  // 1-1 mapping between srcStrides and reassociation packs.
2219  // Each srcStride starts with the given value and gets expanded according to
2220  // the proper entries in resultShape.
2221  // Example:
2222  // srcStrides = [10000, 1 , 100 ],
2223  // reassociations = [ [0], [1], [2, 3, 4]],
2224  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2225  // -> For the purpose of stride calculation, the useful sizes are:
2226  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2227  // resultStrides = [10000, 1, 600, 200, 100]
2228  // Note that a stride does not get expanded along the first entry of each
2229  // shape pack.
2230  SmallVector<int64_t> reverseResultStrides;
2231  reverseResultStrides.reserve(resultShape.size());
2232  unsigned shapeIndex = resultShape.size() - 1;
2233  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2234  ReassociationIndices reassoc = std::get<0>(it);
2235  int64_t currentStrideToExpand = std::get<1>(it);
2236  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2237  reverseResultStrides.push_back(currentStrideToExpand);
2238  currentStrideToExpand =
2239  (SaturatedInteger::wrap(currentStrideToExpand) *
2240  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2241  .asInteger();
2242  }
2243  }
2244  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2245  resultStrides.resize(resultShape.size(), 1);
2246  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2247 }
2248 
2249 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2250  MemRefType srcType, ArrayRef<int64_t> resultShape,
2251  ArrayRef<ReassociationIndices> reassociation) {
2252  if (srcType.getLayout().isIdentity()) {
2253  // If the source is contiguous (i.e., no layout map specified), so is the
2254  // result.
2255  MemRefLayoutAttrInterface layout;
2256  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2257  srcType.getMemorySpace());
2258  }
2259 
2260  // Source may not be contiguous. Compute the layout map.
2261  FailureOr<StridedLayoutAttr> computedLayout =
2262  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2263  if (failed(computedLayout))
2264  return failure();
2265  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2266  srcType.getMemorySpace());
2267 }
2268 
2269 FailureOr<SmallVector<OpFoldResult>>
2270 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2271  MemRefType expandedType,
2272  ArrayRef<ReassociationIndices> reassociation,
2273  ArrayRef<OpFoldResult> inputShape) {
2274  std::optional<SmallVector<OpFoldResult>> outputShape =
2275  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2276  inputShape);
2277  if (!outputShape)
2278  return failure();
2279  return *outputShape;
2280 }
2281 
2282 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2283  Type resultType, Value src,
2284  ArrayRef<ReassociationIndices> reassociation,
2285  ArrayRef<OpFoldResult> outputShape) {
2286  auto [staticOutputShape, dynamicOutputShape] =
2287  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2288  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2289  getReassociationIndicesAttribute(builder, reassociation),
2290  dynamicOutputShape, staticOutputShape);
2291 }
2292 
2293 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2294  Type resultType, Value src,
2295  ArrayRef<ReassociationIndices> reassociation) {
2296  SmallVector<OpFoldResult> inputShape =
2297  getMixedSizes(builder, result.location, src);
2298  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2299  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2300  builder, result.location, memrefResultTy, reassociation, inputShape);
2301  // Failure of this assertion usually indicates presence of multiple
2302  // dynamic dimensions in the same reassociation group.
2303  assert(succeeded(outputShape) && "unable to infer output shape");
2304  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2305 }
2306 
2307 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2308  ArrayRef<int64_t> resultShape, Value src,
2309  ArrayRef<ReassociationIndices> reassociation) {
2310  // Only ranked memref source values are supported.
2311  auto srcType = llvm::cast<MemRefType>(src.getType());
2312  FailureOr<MemRefType> resultType =
2313  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2314  // Failure of this assertion usually indicates a problem with the source
2315  // type, e.g., could not get strides/offset.
2316  assert(succeeded(resultType) && "could not compute layout");
2317  build(builder, result, *resultType, src, reassociation);
2318 }
2319 
2320 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2321  ArrayRef<int64_t> resultShape, Value src,
2322  ArrayRef<ReassociationIndices> reassociation,
2323  ArrayRef<OpFoldResult> outputShape) {
2324  // Only ranked memref source values are supported.
2325  auto srcType = llvm::cast<MemRefType>(src.getType());
2326  FailureOr<MemRefType> resultType =
2327  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2328  // Failure of this assertion usually indicates a problem with the source
2329  // type, e.g., could not get strides/offset.
2330  assert(succeeded(resultType) && "could not compute layout");
2331  build(builder, result, *resultType, src, reassociation, outputShape);
2332 }
2333 
2334 LogicalResult ExpandShapeOp::verify() {
2335  MemRefType srcType = getSrcType();
2336  MemRefType resultType = getResultType();
2337 
2338  if (srcType.getRank() > resultType.getRank()) {
2339  auto r0 = srcType.getRank();
2340  auto r1 = resultType.getRank();
2341  return emitOpError("has source rank ")
2342  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2343  << r0 << " > " << r1 << ").";
2344  }
2345 
2346  // Verify result shape.
2347  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2348  resultType.getShape(),
2349  getReassociationIndices(),
2350  /*allowMultipleDynamicDimsPerGroup=*/true)))
2351  return failure();
2352 
2353  // Compute expected result type (including layout map).
2354  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2355  srcType, resultType.getShape(), getReassociationIndices());
2356  if (failed(expectedResultType))
2357  return emitOpError("invalid source layout map");
2358 
2359  // Check actual result type.
2360  if (*expectedResultType != resultType)
2361  return emitOpError("expected expanded type to be ")
2362  << *expectedResultType << " but found " << resultType;
2363 
2364  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2365  return emitOpError("expected number of static shape bounds to be equal to "
2366  "the output rank (")
2367  << resultType.getRank() << ") but found "
2368  << getStaticOutputShape().size() << " inputs instead";
2369 
2370  if ((int64_t)getOutputShape().size() !=
2371  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2372  return emitOpError("mismatch in dynamic dims in output_shape and "
2373  "static_output_shape: static_output_shape has ")
2374  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2375  << " dynamic dims while output_shape has " << getOutputShape().size()
2376  << " values";
2377 
2378  // Verify if provided output shapes are in agreement with output type.
2379  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2380  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2381  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2382  if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
2383  return emitOpError("invalid output shape provided at pos ") << pos;
2384  }
2385  }
2386 
2387  return success();
2388 }
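// As an illustrative sketch (shapes are placeholders), an expand_shape that
// passes this verifier:
//
//   %r = memref.expand_shape %m [[0, 1], [2]] output_shape [2, 5, 4]
//       : memref<10x4xf32> into memref<2x5x4xf32>
//
// Group [0, 1] expands the source dim of size 10 into 2x5, and group [2]
// leaves the trailing dim of size 4 unchanged.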
2389 
2390 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2391  MLIRContext *context) {
2392  results.add<
2393  ComposeReassociativeReshapeOps<ExpandShapeOp, ReassociationKind::Expand>,
2394  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2395 }
2396 
2397 /// Compute the layout map after collapsing a given source MemRef type with the
2398 /// specified reassociation indices.
2399 ///
2400 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2401 /// not possible to check this by inspecting a MemRefType in the general case.
2402 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2403 /// be valid (and thus accepted by this function) unless `strict = true`.
2404 static FailureOr<StridedLayoutAttr>
2405 computeCollapsedLayoutMap(MemRefType srcType,
2406  ArrayRef<ReassociationIndices> reassociation,
2407  bool strict = false) {
2408  int64_t srcOffset;
2409  SmallVector<int64_t> srcStrides;
2410  auto srcShape = srcType.getShape();
2411  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2412  return failure();
2413 
2414  // The result stride of a reassociation group is the stride of the last entry
2415  // of the reassociation. (TODO: Should be the minimum stride in the
2416  // reassociation because strides are not necessarily sorted. E.g., when using
2417  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2418  // strides are meaningless and could have any arbitrary value.
2419  SmallVector<int64_t> resultStrides;
2420  resultStrides.reserve(reassociation.size());
2421  for (const ReassociationIndices &reassoc : reassociation) {
2422  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2423  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2424  ref = ref.drop_back();
2425  if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2426  resultStrides.push_back(srcStrides[ref.back()]);
2427  } else {
2428  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2429  // the corresponding stride may have to be skipped. (See above comment.)
2430  // Therefore, the result stride cannot be statically determined and must
2431  // be dynamic.
2432  resultStrides.push_back(ShapedType::kDynamic);
2433  }
2434  }
2435 
2436  // Validate that each reassociation group is contiguous.
2437  unsigned resultStrideIndex = resultStrides.size() - 1;
2438  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2439  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2440  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2441  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2442  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2443 
2444  // Both source and result stride must have the same static value. In that
2445  // case, we can be sure, that the dimensions are collapsible (because they
2446  // are contiguous).
2447  // If `strict = false` (default during op verification), we accept cases
2448  // where one or both strides are dynamic. This is best effort: We reject
2449  // ops where obviously non-contiguous dims are collapsed, but accept ops
2450  // where we cannot be sure statically. Such ops may fail at runtime. See
2451  // the op documentation for details.
2452  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2453  if (strict && (stride.saturated || srcStride.saturated))
2454  return failure();
2455 
2456  // Dimensions of size 1 should be skipped, because their strides are
2457  // meaningless and could have any arbitrary value.
2458  if (srcShape[idx - 1] == 1)
2459  continue;
2460 
2461  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2462  return failure();
2463  }
2464  }
2465  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2466 }
2467 
2468 bool CollapseShapeOp::isGuaranteedCollapsible(
2469  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2470  // MemRefs with identity layout are always collapsible.
2471  if (srcType.getLayout().isIdentity())
2472  return true;
2473 
2474  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2475  /*strict=*/true));
2476 }
2477 
2478 MemRefType CollapseShapeOp::computeCollapsedType(
2479  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2480  SmallVector<int64_t> resultShape;
2481  resultShape.reserve(reassociation.size());
2482  for (const ReassociationIndices &group : reassociation) {
2483  auto groupSize = SaturatedInteger::wrap(1);
2484  for (int64_t srcDim : group)
2485  groupSize =
2486  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2487  resultShape.push_back(groupSize.asInteger());
2488  }
2489 
2490  if (srcType.getLayout().isIdentity()) {
2491  // If the source is contiguous (i.e., no layout map specified), so is the
2492  // result.
2493  MemRefLayoutAttrInterface layout;
2494  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2495  srcType.getMemorySpace());
2496  }
2497 
2498  // Source may not be fully contiguous. Compute the layout map.
2499  // Note: Dimensions that are collapsed into a single dim are assumed to be
2500  // contiguous.
2501  FailureOr<StridedLayoutAttr> computedLayout =
2502  computeCollapsedLayoutMap(srcType, reassociation);
2503  assert(succeeded(computedLayout) &&
2504  "invalid source layout map or collapsing non-contiguous dims");
2505  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2506  srcType.getMemorySpace());
2507 }
2508 
2509 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2510  ArrayRef<ReassociationIndices> reassociation,
2511  ArrayRef<NamedAttribute> attrs) {
2512  auto srcType = llvm::cast<MemRefType>(src.getType());
2513  MemRefType resultType =
2514  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2515  result.addAttribute(::mlir::getReassociationAttrName(),
2516  getReassociationIndicesAttribute(b, reassociation));
2517  build(b, result, resultType, src, attrs);
2518 }
2519 
2520 LogicalResult CollapseShapeOp::verify() {
2521  MemRefType srcType = getSrcType();
2522  MemRefType resultType = getResultType();
2523 
2524  if (srcType.getRank() < resultType.getRank()) {
2525  auto r0 = srcType.getRank();
2526  auto r1 = resultType.getRank();
2527  return emitOpError("has source rank ")
2528  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2529  << r0 << " < " << r1 << ").";
2530  }
2531 
2532  // Verify result shape.
2533  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2534  srcType.getShape(), getReassociationIndices(),
2535  /*allowMultipleDynamicDimsPerGroup=*/true)))
2536  return failure();
2537 
2538  // Compute expected result type (including layout map).
2539  MemRefType expectedResultType;
2540  if (srcType.getLayout().isIdentity()) {
2541  // If the source is contiguous (i.e., no layout map specified), so is the
2542  // result.
2543  MemRefLayoutAttrInterface layout;
2544  expectedResultType =
2545  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2546  srcType.getMemorySpace());
2547  } else {
2548  // Source may not be fully contiguous. Compute the layout map.
2549  // Note: Dimensions that are collapsed into a single dim are assumed to be
2550  // contiguous.
2551  FailureOr<StridedLayoutAttr> computedLayout =
2552  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2553  if (failed(computedLayout))
2554  return emitOpError(
2555  "invalid source layout map or collapsing non-contiguous dims");
2556  expectedResultType =
2557  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2558  *computedLayout, srcType.getMemorySpace());
2559  }
2560 
2561  if (expectedResultType != resultType)
2562  return emitOpError("expected collapsed type to be ")
2563  << expectedResultType << " but found " << resultType;
2564 
2565  return success();
2566 }
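// As an illustrative sketch (shapes are placeholders), a collapse_shape that
// passes this verifier folds the leading 2x5 dims into a single dim of 10:
//
//   %c = memref.collapse_shape %r [[0, 1], [2]]
//       : memref<2x5x4xf32> into memref<10x4xf32>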
2567 
2568 class CollapseShapeOpMemRefCastFolder final
2569  : public OpRewritePattern<CollapseShapeOp> {
2570 public:
2571  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2572 
2573  LogicalResult matchAndRewrite(CollapseShapeOp op,
2574  PatternRewriter &rewriter) const override {
2575  auto cast = op.getOperand().getDefiningOp<CastOp>();
2576  if (!cast)
2577  return failure();
2578 
2579  if (!CastOp::canFoldIntoConsumerOp(cast))
2580  return failure();
2581 
2582  Type newResultType = CollapseShapeOp::computeCollapsedType(
2583  llvm::cast<MemRefType>(cast.getOperand().getType()),
2584  op.getReassociationIndices());
2585 
2586  if (newResultType == op.getResultType()) {
2587  rewriter.modifyOpInPlace(
2588  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2589  } else {
2590  Value newOp = rewriter.create<CollapseShapeOp>(
2591  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2592  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2593  }
2594  return success();
2595  }
2596 };
2597 
2598 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2599  MLIRContext *context) {
2600  results.add<
2601  ComposeReassociativeReshapeOps<CollapseShapeOp, ReassociationKind::Collapse>,
2602  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2603  memref::DimOp, MemRefType>,
2604  CollapseShapeOpMemRefCastFolder>(context);
2605 }
2606 
2607 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2608  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2609  adaptor.getOperands());
2610 }
2611 
2612 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2613  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2614  adaptor.getOperands());
2615 }
2616 
2617 //===----------------------------------------------------------------------===//
2618 // ReshapeOp
2619 //===----------------------------------------------------------------------===//
2620 
2621 void ReshapeOp::getAsmResultNames(
2622  function_ref<void(Value, StringRef)> setNameFn) {
2623  setNameFn(getResult(), "reshape");
2624 }
2625 
2626 LogicalResult ReshapeOp::verify() {
2627  Type operandType = getSource().getType();
2628  Type resultType = getResult().getType();
2629 
2630  Type operandElementType =
2631  llvm::cast<ShapedType>(operandType).getElementType();
2632  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2633  if (operandElementType != resultElementType)
2634  return emitOpError("element types of source and destination memref "
2635  "types should be the same");
2636 
2637  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2638  if (!operandMemRefType.getLayout().isIdentity())
2639  return emitOpError("source memref type should have identity affine map");
2640 
2641  int64_t shapeSize =
2642  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2643  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2644  if (resultMemRefType) {
2645  if (!resultMemRefType.getLayout().isIdentity())
2646  return emitOpError("result memref type should have identity affine map");
2647  if (shapeSize == ShapedType::kDynamic)
2648  return emitOpError("cannot use shape operand with dynamic length to "
2649  "reshape to statically-ranked memref type");
2650  if (shapeSize != resultMemRefType.getRank())
2651  return emitOpError(
2652  "length of shape operand differs from the result's memref rank");
2653  }
2654  return success();
2655 }
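// As an illustrative sketch (names are placeholders), a reshape accepted by
// this verifier uses a statically sized shape operand whose length equals the
// result rank:
//
//   %dst = memref.reshape %src(%shape)
//       : (memref<4x5xf32>, memref<2xi32>) -> memref<?x?xf32>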
2656 
2657 //===----------------------------------------------------------------------===//
2658 // StoreOp
2659 //===----------------------------------------------------------------------===//
2660 
2661 LogicalResult StoreOp::verify() {
2662  if (getNumOperands() != 2 + getMemRefType().getRank())
2663  return emitOpError("store index operand count not equal to memref rank");
2664 
2665  return success();
2666 }
2667 
2668 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2669  SmallVectorImpl<OpFoldResult> &results) {
2670  /// store(memrefcast) -> store
2671  return foldMemRefCast(*this, getValueToStore());
2672 }
2673 
2674 //===----------------------------------------------------------------------===//
2675 // SubViewOp
2676 //===----------------------------------------------------------------------===//
2677 
2678 void SubViewOp::getAsmResultNames(
2679  function_ref<void(Value, StringRef)> setNameFn) {
2680  setNameFn(getResult(), "subview");
2681 }
2682 
2683 /// A subview result type can be fully inferred from the source type and the
2684 /// static representation of offsets, sizes and strides. Special sentinels
2685 /// encode the dynamic case.
2686 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2687  ArrayRef<int64_t> staticOffsets,
2688  ArrayRef<int64_t> staticSizes,
2689  ArrayRef<int64_t> staticStrides) {
2690  unsigned rank = sourceMemRefType.getRank();
2691  (void)rank;
2692  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2693  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2694  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2695 
2696  // Extract source offset and strides.
2697  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2698 
2699  // Compute target offset whose value is:
2700  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2701  int64_t targetOffset = sourceOffset;
2702  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2703  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2704  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2705  SaturatedInteger::wrap(staticOffset) *
2706  SaturatedInteger::wrap(sourceStride))
2707  .asInteger();
2708  }
2709 
2710  // Compute target stride whose value is:
2711  // `sourceStrides_i * staticStrides_i`.
2712  SmallVector<int64_t, 4> targetStrides;
2713  targetStrides.reserve(staticOffsets.size());
2714  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2715  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2716  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2717  SaturatedInteger::wrap(staticStride))
2718  .asInteger());
2719  }
2720 
2721  // The type is now known.
2722  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2723  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2724  targetOffset, targetStrides),
2725  sourceMemRefType.getMemorySpace());
2726 }
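// As an illustrative worked example of the computation above: for a source
// memref<8x16xf32> (identity layout: strides [16, 1], offset 0) with static
// offsets [2, 4], sizes [4, 8], and strides [1, 2], the target offset is
// 0 + 2*16 + 4*1 = 36 and the target strides are [16*1, 1*2] = [16, 2], so the
// inferred type is memref<4x8xf32, strided<[16, 2], offset: 36>>.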
2727 
2728 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2729  ArrayRef<OpFoldResult> offsets,
2730  ArrayRef<OpFoldResult> sizes,
2731  ArrayRef<OpFoldResult> strides) {
2732  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2733  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2734  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2735  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2736  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2737  if (!hasValidSizesOffsets(staticOffsets))
2738  return {};
2739  if (!hasValidSizesOffsets(staticSizes))
2740  return {};
2741  if (!hasValidStrides(staticStrides))
2742  return {};
2743  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2744  staticSizes, staticStrides);
2745 }
2746 
2747 MemRefType SubViewOp::inferRankReducedResultType(
2748  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2749  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2750  ArrayRef<int64_t> strides) {
2751  MemRefType inferredType =
2752  inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2753  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2754  "expected result shape rank to not exceed the inferred rank");
2755  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2756  return inferredType;
2757 
2758  // Compute which dimensions are dropped.
2759  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2760  computeRankReductionMask(inferredType.getShape(), resultShape);
2761  assert(dimsToProject.has_value() && "invalid rank reduction");
2762 
2763  // Compute the layout and result type.
2764  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2765  SmallVector<int64_t> rankReducedStrides;
2766  rankReducedStrides.reserve(resultShape.size());
2767  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2768  if (!dimsToProject->contains(idx))
2769  rankReducedStrides.push_back(value);
2770  }
2771  return MemRefType::get(resultShape, inferredType.getElementType(),
2772  StridedLayoutAttr::get(inferredLayout.getContext(),
2773  inferredLayout.getOffset(),
2774  rankReducedStrides),
2775  inferredType.getMemorySpace());
2776 }
2777 
2778 MemRefType SubViewOp::inferRankReducedResultType(
2779  ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2780  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2781  ArrayRef<OpFoldResult> strides) {
2782  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2783  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2784  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2785  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2786  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2787  return SubViewOp::inferRankReducedResultType(
2788  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2789  staticStrides);
2790 }
2791 
2792 // Build a SubViewOp with mixed static and dynamic entries and custom result
2793 // type. If the type passed is nullptr, it is inferred.
2794 void SubViewOp::build(OpBuilder &b, OperationState &result,
2795  MemRefType resultType, Value source,
2796  ArrayRef<OpFoldResult> offsets,
2797  ArrayRef<OpFoldResult> sizes,
2798  ArrayRef<OpFoldResult> strides,
2799  ArrayRef<NamedAttribute> attrs) {
2800  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2801  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2802  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2803  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2804  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2805  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2806  // Structuring implementation this way avoids duplication between builders.
2807  if (!resultType) {
2808  resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2809  staticSizes, staticStrides);
2810  }
2811  result.addAttributes(attrs);
2812  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2813  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2814  b.getDenseI64ArrayAttr(staticSizes),
2815  b.getDenseI64ArrayAttr(staticStrides));
2816 }
2817 
2818 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2819 // type.
2820 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2821  ArrayRef<OpFoldResult> offsets,
2822  ArrayRef<OpFoldResult> sizes,
2823  ArrayRef<OpFoldResult> strides,
2824  ArrayRef<NamedAttribute> attrs) {
2825  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2826 }
2827 
2828 // Build a SubViewOp with static entries and inferred result type.
2829 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2830  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2831  ArrayRef<int64_t> strides,
2832  ArrayRef<NamedAttribute> attrs) {
2833  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2834  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2835  return b.getI64IntegerAttr(v);
2836  }));
2837  SmallVector<OpFoldResult> sizeValues =
2838  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2839  return b.getI64IntegerAttr(v);
2840  }));
2841  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2842  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2843  return b.getI64IntegerAttr(v);
2844  }));
2845  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2846 }
2847 
2848 // Build a SubViewOp with static entries and custom result type. If the
2849 // type passed is nullptr, it is inferred.
2850 void SubViewOp::build(OpBuilder &b, OperationState &result,
2851  MemRefType resultType, Value source,
2852  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2853  ArrayRef<int64_t> strides,
2854  ArrayRef<NamedAttribute> attrs) {
2855  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2856  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2857  return b.getI64IntegerAttr(v);
2858  }));
2859  SmallVector<OpFoldResult> sizeValues =
2860  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2861  return b.getI64IntegerAttr(v);
2862  }));
2863  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2864  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2865  return b.getI64IntegerAttr(v);
2866  }));
2867  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2868  attrs);
2869 }
2870 
2871 // Build a SubViewOp with dynamic entries and custom result type. If the type
2872 // passed is nullptr, it is inferred.
2873 void SubViewOp::build(OpBuilder &b, OperationState &result,
2874  MemRefType resultType, Value source, ValueRange offsets,
2875  ValueRange sizes, ValueRange strides,
2876  ArrayRef<NamedAttribute> attrs) {
2877  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2878  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2879  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2880  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2881  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2882  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2883  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2884 }
2885 
2886 // Build a SubViewOp with dynamic entries and inferred result type.
2887 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2888  ValueRange offsets, ValueRange sizes, ValueRange strides,
2889  ArrayRef<NamedAttribute> attrs) {
2890  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2891 }
2892 
2893 /// For ViewLikeOpInterface.
2894 Value SubViewOp::getViewSource() { return getSource(); }
2895 
2896 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2897 /// static value).
2898 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2899  int64_t t1Offset, t2Offset;
2900  SmallVector<int64_t> t1Strides, t2Strides;
2901  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2902  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2903  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2904 }
2905 
2906 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2907 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2908 /// marked as dropped in `droppedDims`.
2909 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2910  const llvm::SmallBitVector &droppedDims) {
2911  assert(size_t(t1.getRank()) == droppedDims.size() &&
2912  "incorrect number of bits");
2913  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2914  "incorrect number of dropped dims");
2915  int64_t t1Offset, t2Offset;
2916  SmallVector<int64_t> t1Strides, t2Strides;
2917  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2918  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2919  if (failed(res1) || failed(res2))
2920  return false;
2921  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2922  if (droppedDims[i])
2923  continue;
2924  if (t1Strides[i] != t2Strides[j])
2925  return false;
2926  ++j;
2927  }
2928  return true;
2929 }
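// For example (hypothetical types, not from a test): with
//   t1 = memref<4x1x8xf32>                  (identity strides [8, 8, 1]) and
//   t2 = memref<4x8xf32, strided<[8, 1]>>,
// a `droppedDims` mask with only bit 1 set skips the unit dimension of `t1`;
// the remaining strides compare as [8, 1] vs. [8, 1], so the two types are
// considered compatible.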
2930 
2931 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2932  SubViewOp op, Type expectedType) {
2933  auto memrefType = llvm::cast<ShapedType>(expectedType);
2934  switch (result) {
2935  case SliceVerificationResult::Success:
2936  return success();
2937  case SliceVerificationResult::RankTooLarge:
2938  return op->emitError("expected result rank to be smaller or equal to ")
2939  << "the source rank, but got " << op.getType();
2940  case SliceVerificationResult::SizeMismatch:
2941  return op->emitError("expected result type to be ")
2942  << expectedType
2943  << " or a rank-reduced version. (mismatch of result sizes), but got "
2944  << op.getType();
2945  case SliceVerificationResult::ElemTypeMismatch:
2946  return op->emitError("expected result element type to be ")
2947  << memrefType.getElementType() << ", but got " << op.getType();
2948  case SliceVerificationResult::MemSpaceMismatch:
2949  return op->emitError(
2950  "expected result and source memory spaces to match, but got ")
2951  << op.getType();
2952  case SliceVerificationResult::LayoutMismatch:
2953  return op->emitError("expected result type to be ")
2954  << expectedType
2955  << " or a rank-reduced version. (mismatch of result layout), but "
2956  "got "
2957  << op.getType();
2958  }
2959  llvm_unreachable("unexpected subview verification result");
2960 }
2961 
2962 /// Verifier for SubViewOp.
2963 LogicalResult SubViewOp::verify() {
2964  MemRefType baseType = getSourceType();
2965  MemRefType subViewType = getType();
2966  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
2967  ArrayRef<int64_t> staticSizes = getStaticSizes();
2968  ArrayRef<int64_t> staticStrides = getStaticStrides();
2969 
2970  // The base memref and the view memref should be in the same memory space.
2971  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2972  return emitError("different memory spaces specified for base memref "
2973  "type ")
2974  << baseType << " and subview memref type " << subViewType;
2975 
2976  // Verify that the base memref type has a strided layout map.
2977  if (!baseType.isStrided())
2978  return emitError("base type ") << baseType << " is not strided";
2979 
2980  // Compute the expected result type, assuming that there are no rank
2981  // reductions.
2982  MemRefType expectedType = SubViewOp::inferResultType(
2983  baseType, staticOffsets, staticSizes, staticStrides);
2984 
2985  // Verify all properties of a shaped type: rank, element type and dimension
2986  // sizes. This takes into account potential rank reductions.
2987  auto shapedTypeVerification = isRankReducedType(
2988  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2989  if (shapedTypeVerification != SliceVerificationResult::Success)
2990  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2991 
2992  // Make sure that the memory space did not change.
2993  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
2994  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
2995  *this, expectedType);
2996 
2997  // Verify the offset of the layout map.
2998  if (!haveCompatibleOffsets(expectedType, subViewType))
2999  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3000  *this, expectedType);
3001 
3002  // All that is left to verify now are the strides. First, compute
3003  // the unused dimensions due to rank reductions. We have to look at sizes and
3004  // strides to decide which dimensions were dropped. This function also
3005  // partially verifies strides in case of rank reductions.
3006  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3007  getMixedSizes());
3008  if (failed(unusedDims))
3009  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3010  *this, expectedType);
3011 
3012  // Strides must match.
3013  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3014  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3015  *this, expectedType);
3016 
3017  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3018  // to the base memref.
3019  SliceBoundsVerificationResult boundsResult =
3020  verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
3021  staticStrides, /*generateErrorMessage=*/true);
3022  if (!boundsResult.isValid)
3023  return getOperation()->emitError(boundsResult.errorMessage);
3024 
3025  return success();
3026 }
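// For illustration, a rank-reducing subview that satisfies all of the checks
// above (the SSA names and shapes are hypothetical):
//
// ```
// %0 = memref.subview %src[0, 0, 0] [1, 16, 4] [1, 1, 1]
//     : memref<8x16x4xf32> to memref<16x4xf32, strided<[4, 1]>>
// ```
//
// The inferred non-rank-reduced type is memref<1x16x4xf32, strided<[64, 4, 1]>>;
// dropping the unit dimension yields the result type above with matching
// offset and strides.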
3027 
3028 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3029  return os << "range " << range.offset << ":" << range.size << ":"
3030  << range.stride;
3031 }
3032 
3033 /// Return the list of Range (i.e. offset, size, stride). Each Range
3034 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3035 /// with `b` at location `loc`.
3036 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3037  OpBuilder &b, Location loc) {
3038  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3039  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3040  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3041  SmallVector<Range, 8> res;
3042  unsigned rank = ranks[0];
3043  res.reserve(rank);
3044  for (unsigned idx = 0; idx < rank; ++idx) {
3045  Value offset =
3046  op.isDynamicOffset(idx)
3047  ? op.getDynamicOffset(idx)
3048  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
3049  Value size =
3050  op.isDynamicSize(idx)
3051  ? op.getDynamicSize(idx)
3052  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
3053  Value stride =
3054  op.isDynamicStride(idx)
3055  ? op.getDynamicStride(idx)
3056  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
3057  res.emplace_back(Range{offset, size, stride});
3058  }
3059  return res;
3060 }
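// A usage sketch (with hypothetical `offsetSizeAndStrideOp`, `b`, and `loc`):
// every static offset/size/stride is materialized as an arith.constant of
// index type, so each returned Range carries SSA values only.
//
// ```
// SmallVector<Range, 8> ranges = getOrCreateRanges(offsetSizeAndStrideOp, b, loc);
// for (const Range &r : ranges) {
//   // r.offset, r.size, and r.stride are OpFoldResults wrapping Values here.
// }
// ```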
3061 
3062 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3063 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3064 /// the rank of the inferred result type if `currentResultType` is lower rank
3065 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3066 /// together with the result type. In this case, it is important to compute
3067 /// the dropped dimensions using `currentSourceType` whose strides align with
3068 /// `currentResultType`.
3069 static MemRefType getCanonicalSubViewResultType(
3070  MemRefType currentResultType, MemRefType currentSourceType,
3071  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3072  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3073  MemRefType nonRankReducedType = SubViewOp::inferResultType(
3074  sourceType, mixedOffsets, mixedSizes, mixedStrides);
3075  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3076  currentSourceType, currentResultType, mixedSizes);
3077  if (failed(unusedDims))
3078  return nullptr;
3079 
3080  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3081  SmallVector<int64_t> shape, strides;
3082  unsigned numDimsAfterReduction =
3083  nonRankReducedType.getRank() - unusedDims->count();
3084  shape.reserve(numDimsAfterReduction);
3085  strides.reserve(numDimsAfterReduction);
3086  for (const auto &[idx, size, stride] :
3087  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3088  nonRankReducedType.getShape(), layout.getStrides())) {
3089  if (unusedDims->test(idx))
3090  continue;
3091  shape.push_back(size);
3092  strides.push_back(stride);
3093  }
3094 
3095  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3096  StridedLayoutAttr::get(sourceType.getContext(),
3097  layout.getOffset(), strides),
3098  nonRankReducedType.getMemorySpace());
3099 }
3100 
3101 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3102  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3103  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3104  unsigned rank = memrefType.getRank();
3105  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3106  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3107  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3108  MemRefType targetType = SubViewOp::inferRankReducedResultType(
3109  targetShape, memrefType, offsets, sizes, strides);
3110  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3111  sizes, strides);
3112 }
3113 
3114 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3115  Value value,
3116  ArrayRef<int64_t> desiredShape) {
3117  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3118  assert(sourceMemrefType && "not a ranked memref type");
3119  auto sourceShape = sourceMemrefType.getShape();
3120  if (sourceShape.equals(desiredShape))
3121  return value;
3122  auto maybeRankReductionMask =
3123  mlir::computeRankReductionMask(sourceShape, desiredShape);
3124  if (!maybeRankReductionMask)
3125  return failure();
3126  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3127 }
3128 
3129 /// Helper method to check if a `subview` operation is trivially a no-op. This
3130 /// is the case if all offsets are zero, all strides are one, and the source
3131 /// shape matches the sizes of the subview. In such cases, the subview can
3132 /// be folded into its source.
3133 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3134  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3135  return false;
3136 
3137  auto mixedOffsets = subViewOp.getMixedOffsets();
3138  auto mixedSizes = subViewOp.getMixedSizes();
3139  auto mixedStrides = subViewOp.getMixedStrides();
3140 
3141  // Check offsets are zero.
3142  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3143  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3144  return !intValue || intValue.value() != 0;
3145  }))
3146  return false;
3147 
3148  // Check strides are one.
3149  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3150  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3151  return !intValue || intValue.value() != 1;
3152  }))
3153  return false;
3154 
3155  // Check that all size values are static and match the (static) source shape.
3156  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3157  for (const auto &size : llvm::enumerate(mixedSizes)) {
3158  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3159  if (!intValue || *intValue != sourceShape[size.index()])
3160  return false;
3161  }
3162  // All conditions met. The `SubViewOp` is foldable as a no-op.
3163  return true;
3164 }
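// For illustration (hypothetical IR), a subview this helper classifies as
// trivial: all offsets are 0, all strides are 1, and the sizes equal the
// static source shape.
//
// ```
// %1 = memref.subview %0[0, 0] [4, 8] [1, 1]
//     : memref<4x8xf32> to memref<4x8xf32>
// ```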
3165 
3166 namespace {
3167 /// Pattern to rewrite a subview op with MemRefCast arguments.
3168 /// This essentially pushes memref.cast past its consuming subview when
3169 /// `canFoldIntoConsumerOp` is true.
3170 ///
3171 /// Example:
3172 /// ```
3173 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3174 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3175 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3176 /// ```
3177 /// is rewritten into:
3178 /// ```
3179 /// %0 = memref.subview %V[0, 0][3, 4][1, 1] : memref<16x16xf32> to
3180 /// memref<3x4xf32, strided<[16, 1], offset: 0>>
3181 /// %1 = memref.cast %0 : memref<3x4xf32, strided<[16, 1], offset: 0>> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3182 /// ```
3183 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3184 public:
3185  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3186 
3187  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3188  PatternRewriter &rewriter) const override {
3189  // Any constant operand, just return to let SubViewOpConstantFolder kick
3190  // in.
3191  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3192  return matchPattern(operand, matchConstantIndex());
3193  }))
3194  return failure();
3195 
3196  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3197  if (!castOp)
3198  return failure();
3199 
3200  if (!CastOp::canFoldIntoConsumerOp(castOp))
3201  return failure();
3202 
3203  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3204  // the MemRefCastOp source operand type to infer the result type and the
3205  // current SubViewOp source operand type to compute the dropped dimensions
3206  // if the operation is rank-reducing.
3207  auto resultType = getCanonicalSubViewResultType(
3208  subViewOp.getType(), subViewOp.getSourceType(),
3209  llvm::cast<MemRefType>(castOp.getSource().getType()),
3210  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3211  subViewOp.getMixedStrides());
3212  if (!resultType)
3213  return failure();
3214 
3215  Value newSubView = rewriter.create<SubViewOp>(
3216  subViewOp.getLoc(), resultType, castOp.getSource(),
3217  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3218  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3219  subViewOp.getStaticStrides());
3220  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3221  newSubView);
3222  return success();
3223  }
3224 };
3225 
3226 /// Canonicalize subview ops that are no-ops. If the source and result types
3227 /// differ (e.g. due to the use of `affine_map`), the subview is replaced by a cast.
3228 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3229 public:
3230  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3231 
3232  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3233  PatternRewriter &rewriter) const override {
3234  if (!isTrivialSubViewOp(subViewOp))
3235  return failure();
3236  if (subViewOp.getSourceType() == subViewOp.getType()) {
3237  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3238  return success();
3239  }
3240  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3241  subViewOp.getSource());
3242  return success();
3243  }
3244 };
3245 } // namespace
3246 
3247 /// Return the canonical type of the result of a subview.
3248 struct SubViewReturnTypeCanonicalizer {
3249  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3250  ArrayRef<OpFoldResult> mixedSizes,
3251  ArrayRef<OpFoldResult> mixedStrides) {
3252  // Infer a memref type without taking into account any rank reductions.
3253  MemRefType resTy = SubViewOp::inferResultType(
3254  op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3255  if (!resTy)
3256  return {};
3257  MemRefType nonReducedType = resTy;
3258 
3259  // Directly return the non-rank reduced type if there are no dropped dims.
3260  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3261  if (droppedDims.none())
3262  return nonReducedType;
3263 
3264  // Take the strides and offset from the non-rank reduced type.
3265  auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3266 
3267  // Drop dims from shape and strides.
3268  SmallVector<int64_t> targetShape;
3269  SmallVector<int64_t> targetStrides;
3270  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3271  if (droppedDims.test(i))
3272  continue;
3273  targetStrides.push_back(nonReducedStrides[i]);
3274  targetShape.push_back(nonReducedType.getDimSize(i));
3275  }
3276 
3277  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3278  StridedLayoutAttr::get(nonReducedType.getContext(),
3279  offset, targetStrides),
3280  nonReducedType.getMemorySpace());
3281  }
3282 };
3283 
3284 /// A canonicalizer wrapper to replace SubViewOps.
3285 struct SubViewCanonicalizer {
3286  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3287  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3288  }
3289 };
3290 
3291 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3292  MLIRContext *context) {
3293  results
3294  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3295  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3296  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3297 }
3298 
3299 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3300  MemRefType sourceMemrefType = getSource().getType();
3301  MemRefType resultMemrefType = getResult().getType();
3302  auto resultLayout =
3303  dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3304 
3305  if (resultMemrefType == sourceMemrefType &&
3306  resultMemrefType.hasStaticShape() &&
3307  (!resultLayout || resultLayout.hasStaticLayout())) {
3308  return getViewSource();
3309  }
3310 
3311  // Fold subview(subview(x)), where both subviews have the same size and the
3312  // second subview's offsets are all zero. (I.e., the second subview is a
3313  // no-op.)
3314  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3315  auto srcSizes = srcSubview.getMixedSizes();
3316  auto sizes = getMixedSizes();
3317  auto offsets = getMixedOffsets();
3318  bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3319  auto strides = getMixedStrides();
3320  bool allStridesOne = llvm::all_of(strides, isOneInteger);
3321  bool allSizesSame = llvm::equal(sizes, srcSizes);
3322  if (allOffsetsZero && allStridesOne && allSizesSame &&
3323  resultMemrefType == sourceMemrefType)
3324  return getViewSource();
3325  }
3326 
3327  return {};
3328 }
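// For illustration (hypothetical IR), the subview(subview(x)) branch above
// folds %1 to %0 below: %1 has zero offsets, unit strides, the same sizes as
// its producer, and a result type equal to its source type.
//
// ```
// %0 = memref.subview %arg[%i, %j] [4, 4] [1, 1]
//     : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: ?>>
// %1 = memref.subview %0[0, 0] [4, 4] [1, 1]
//     : memref<4x4xf32, strided<[8, 1], offset: ?>>
//       to memref<4x4xf32, strided<[8, 1], offset: ?>>
// ```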
3329 
3330 //===----------------------------------------------------------------------===//
3331 // TransposeOp
3332 //===----------------------------------------------------------------------===//
3333 
3334 void TransposeOp::getAsmResultNames(
3335  function_ref<void(Value, StringRef)> setNameFn) {
3336  setNameFn(getResult(), "transpose");
3337 }
3338 
3339 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3340 static MemRefType inferTransposeResultType(MemRefType memRefType,
3341  AffineMap permutationMap) {
3342  auto originalSizes = memRefType.getShape();
3343  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3344  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3345 
3346  // Compute permuted sizes and strides.
3347  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3348  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3349 
3350  return MemRefType::Builder(memRefType)
3351  .setShape(sizes)
3352  .setLayout(
3353  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3354 }
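// For example (hypothetical type): applying (d0, d1) -> (d1, d0) to
// memref<3x4xf32>, whose identity layout has strides [4, 1], permutes both
// the shape and the strides, producing memref<4x3xf32, strided<[1, 4]>>.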
3355 
3356 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3357  AffineMapAttr permutation,
3358  ArrayRef<NamedAttribute> attrs) {
3359  auto permutationMap = permutation.getValue();
3360  assert(permutationMap);
3361 
3362  auto memRefType = llvm::cast<MemRefType>(in.getType());
3363  // Compute result type.
3364  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3365 
3366  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3367  build(b, result, resultType, in, attrs);
3368 }
3369 
3370 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3371 void TransposeOp::print(OpAsmPrinter &p) {
3372  p << " " << getIn() << " " << getPermutation();
3373  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3374  p << " : " << getIn().getType() << " to " << getType();
3375 }
3376 
3377 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3378  OpAsmParser::UnresolvedOperand in;
3379  AffineMap permutation;
3380  MemRefType srcType, dstType;
3381  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3382  parser.parseOptionalAttrDict(result.attributes) ||
3383  parser.parseColonType(srcType) ||
3384  parser.resolveOperand(in, srcType, result.operands) ||
3385  parser.parseKeywordType("to", dstType) ||
3386  parser.addTypeToList(dstType, result.types))
3387  return failure();
3388 
3389  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3390  AffineMapAttr::get(permutation));
3391  return success();
3392 }
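// For illustration (hypothetical SSA names), the custom assembly form handled
// by the printer and parser above looks like:
//
// ```
// %t = memref.transpose %m (d0, d1) -> (d1, d0)
//     : memref<3x4xf32> to memref<4x3xf32, strided<[1, 4]>>
// ```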
3393 
3394 LogicalResult TransposeOp::verify() {
3395  if (!getPermutation().isPermutation())
3396  return emitOpError("expected a permutation map");
3397  if (getPermutation().getNumDims() != getIn().getType().getRank())
3398  return emitOpError("expected a permutation map of same rank as the input");
3399 
3400  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3401  auto resultType = llvm::cast<MemRefType>(getType());
3402  auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3403  .canonicalizeStridedLayout();
3404 
3405  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3406  return emitOpError("result type ")
3407  << resultType
3408  << " is not equivalent to the canonical transposed input type "
3409  << canonicalResultType;
3410  return success();
3411 }
3412 
3413 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3414  // First check for an identity permutation; we can fold it away if the input
3415  // and result types are already identical.
3416  if (getPermutation().isIdentity() && getType() == getIn().getType())
3417  return getIn();
3418  // Fold two consecutive memref.transpose Ops into one by composing their
3419  // permutation maps.
3420  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3421  AffineMap composedPermutation =
3422  getPermutation().compose(otherTransposeOp.getPermutation());
3423  getInMutable().assign(otherTransposeOp.getIn());
3424  setPermutation(composedPermutation);
3425  return getResult();
3426  }
3427  return {};
3428 }
3429 
3430 //===----------------------------------------------------------------------===//
3431 // ViewOp
3432 //===----------------------------------------------------------------------===//
3433 
3434 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3435  setNameFn(getResult(), "view");
3436 }
3437 
3438 LogicalResult ViewOp::verify() {
3439  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3440  auto viewType = getType();
3441 
3442  // The base memref should have identity layout map (or none).
3443  if (!baseType.getLayout().isIdentity())
3444  return emitError("unsupported map for base memref type ") << baseType;
3445 
3446  // The result memref should have identity layout map (or none).
3447  if (!viewType.getLayout().isIdentity())
3448  return emitError("unsupported map for result memref type ") << viewType;
3449 
3450  // The base memref and the view memref should be in the same memory space.
3451  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3452  return emitError("different memory spaces specified for base memref "
3453  "type ")
3454  << baseType << " and view memref type " << viewType;
3455 
3456  // Verify that we have the correct number of sizes for the result type.
3457  unsigned numDynamicDims = viewType.getNumDynamicDims();
3458  if (getSizes().size() != numDynamicDims)
3459  return emitError("incorrect number of size operands for type ") << viewType;
3460 
3461  return success();
3462 }
3463 
3464 Value ViewOp::getViewSource() { return getSource(); }
3465 
3466 OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3467  MemRefType sourceMemrefType = getSource().getType();
3468  MemRefType resultMemrefType = getResult().getType();
3469 
3470  if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
3471  return getViewSource();
3472 
3473  return {};
3474 }
3475 
3476 namespace {
3477 
3478 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3479  using OpRewritePattern<ViewOp>::OpRewritePattern;
3480 
3481  LogicalResult matchAndRewrite(ViewOp viewOp,
3482  PatternRewriter &rewriter) const override {
3483  // Return if none of the operands are constants.
3484  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3485  return matchPattern(operand, matchConstantIndex());
3486  }))
3487  return failure();
3488 
3489  // Get result memref type.
3490  auto memrefType = viewOp.getType();
3491 
3492  // Get offset from old memref view type 'memRefType'.
3493  int64_t oldOffset;
3494  SmallVector<int64_t, 4> oldStrides;
3495  if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3496  return failure();
3497  assert(oldOffset == 0 && "Expected 0 offset");
3498 
3499  SmallVector<Value, 4> newOperands;
3500 
3501  // Offset cannot be folded into result type.
3502 
3503  // Fold any dynamic dim operands which are produced by a constant.
3504  SmallVector<int64_t, 4> newShapeConstants;
3505  newShapeConstants.reserve(memrefType.getRank());
3506 
3507  unsigned dynamicDimPos = 0;
3508  unsigned rank = memrefType.getRank();
3509  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3510  int64_t dimSize = memrefType.getDimSize(dim);
3511  // If this is already a static dimension, keep it.
3512  if (ShapedType::isStatic(dimSize)) {
3513  newShapeConstants.push_back(dimSize);
3514  continue;
3515  }
3516  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3517  if (auto constantIndexOp =
3518  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3519  // Dynamic shape dimension will be folded.
3520  newShapeConstants.push_back(constantIndexOp.value());
3521  } else {
3522  // Dynamic shape dimension not folded; copy operand from old memref.
3523  newShapeConstants.push_back(dimSize);
3524  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3525  }
3526  dynamicDimPos++;
3527  }
3528 
3529  // Create new memref type with constant folded dims.
3530  MemRefType newMemRefType =
3531  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3532  // Nothing new, don't fold.
3533  if (newMemRefType == memrefType)
3534  return failure();
3535 
3536  // Create new ViewOp.
3537  auto newViewOp = rewriter.create<ViewOp>(
3538  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3539  viewOp.getByteShift(), newOperands);
3540  // Insert a cast so we have the same type as the old memref type.
3541  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3542  return success();
3543  }
3544 };
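// For illustration (hypothetical IR), this pattern folds a constant size
// operand into the result type and casts back to the original type:
//
// ```
// %c128 = arith.constant 128 : index
// %v = memref.view %buf[%off][%c128] : memref<16384xi8> to memref<?x16xf32>
// ```
// is rewritten into:
// ```
// %folded = memref.view %buf[%off][] : memref<16384xi8> to memref<128x16xf32>
// %cast = memref.cast %folded : memref<128x16xf32> to memref<?x16xf32>
// ```
// with all uses of %v replaced by %cast.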
3545 
3546 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3547  using OpRewritePattern<ViewOp>::OpRewritePattern;
3548 
3549  LogicalResult matchAndRewrite(ViewOp viewOp,
3550  PatternRewriter &rewriter) const override {
3551  Value memrefOperand = viewOp.getOperand(0);
3552  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3553  if (!memrefCastOp)
3554  return failure();
3555  Value allocOperand = memrefCastOp.getOperand();
3556  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3557  if (!allocOp)
3558  return failure();
3559  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3560  viewOp.getByteShift(),
3561  viewOp.getSizes());
3562  return success();
3563  }
3564 };
3565 
3566 } // namespace
3567 
3568 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3569  MLIRContext *context) {
3570  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3571 }
3572 
3573 //===----------------------------------------------------------------------===//
3574 // AtomicRMWOp
3575 //===----------------------------------------------------------------------===//
3576 
3577 LogicalResult AtomicRMWOp::verify() {
3578  if (getMemRefType().getRank() != getNumOperands() - 2)
3579  return emitOpError(
3580  "expects the number of subscripts to be equal to memref rank");
3581  switch (getKind()) {
3582  case arith::AtomicRMWKind::addf:
3583  case arith::AtomicRMWKind::maximumf:
3584  case arith::AtomicRMWKind::minimumf:
3585  case arith::AtomicRMWKind::mulf:
3586  if (!llvm::isa<FloatType>(getValue().getType()))
3587  return emitOpError() << "with kind '"
3588  << arith::stringifyAtomicRMWKind(getKind())
3589  << "' expects a floating-point type";
3590  break;
3591  case arith::AtomicRMWKind::addi:
3592  case arith::AtomicRMWKind::maxs:
3593  case arith::AtomicRMWKind::maxu:
3594  case arith::AtomicRMWKind::mins:
3595  case arith::AtomicRMWKind::minu:
3596  case arith::AtomicRMWKind::muli:
3597  case arith::AtomicRMWKind::ori:
3598  case arith::AtomicRMWKind::andi:
3599  if (!llvm::isa<IntegerType>(getValue().getType()))
3600  return emitOpError() << "with kind '"
3601  << arith::stringifyAtomicRMWKind(getKind())
3602  << "' expects an integer type";
3603  break;
3604  default:
3605  break;
3606  }
3607  return success();
3608 }
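// For illustration (hypothetical SSA names), a form this verifier accepts:
// the `addf` kind with an f32 value and one index per memref dimension.
//
// ```
// %old = memref.atomic_rmw addf %val, %mem[%i] : (f32, memref<10xf32>) -> f32
// ```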
3609 
3610 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3611  /// atomicrmw(memrefcast) -> atomicrmw
3612  if (succeeded(foldMemRefCast(*this, getValue())))
3613  return getResult();
3614  return OpFoldResult();
3615 }
3616 
3617 //===----------------------------------------------------------------------===//
3618 // TableGen'd op method definitions
3619 //===----------------------------------------------------------------------===//
3620 
3621 #define GET_OP_CLASSES
3622 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"