MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Arith/Utils/Utils.h"
12 #include "mlir/Dialect/Utils/StaticValueUtils.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/ViewLikeInterface.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 /// Materialize a single constant operation from a given attribute value with
30 /// the desired resultant type.
31 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
32  Attribute value, Type type,
33  Location loc) {
34  return arith::ConstantOp::materialize(builder, value, type, loc);
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // Common canonicalization pattern support logic
39 //===----------------------------------------------------------------------===//
40 
41 /// This is a common class used for patterns of the form
42 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
43 /// into the root operation directly.
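/// For illustration, a minimal hypothetical example (the `%src` name is made
/// up) of the kind of rewrite this enables, here via `memref.dealloc`'s fold:
/// ```mlir
/// %0 = memref.cast %src : memref<8xf32> to memref<?xf32>
/// memref.dealloc %0 : memref<?xf32>
/// // can fold to:
/// memref.dealloc %src : memref<8xf32>
/// ```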
44 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
45  bool folded = false;
46  for (OpOperand &operand : op->getOpOperands()) {
47  auto cast = operand.get().getDefiningOp<CastOp>();
48  if (cast && operand.get() != inner &&
49  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
50  operand.set(cast.getOperand());
51  folded = true;
52  }
53  }
54  return success(folded);
55 }
56 
57 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
58 /// type.
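/// For illustration (a sketch of the mapping, not exhaustive):
/// ```
/// memref<4x?xf32> -> tensor<4x?xf32>
/// memref<*xi8>    -> tensor<*xi8>
/// any other type  -> none
/// ```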
59 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
60  if (auto memref = llvm::dyn_cast<MemRefType>(type))
61  return RankedTensorType::get(memref.getShape(), memref.getElementType());
62  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
63  return UnrankedTensorType::get(memref.getElementType());
64  return NoneType::get(type.getContext());
65 }
66 
67 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
68  int64_t dim) {
69  auto memrefType = llvm::cast<MemRefType>(value.getType());
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78  Location loc, Value value) {
79  auto memrefType = llvm::cast<MemRefType>(value.getType());
80  SmallVector<OpFoldResult> result;
81  for (int64_t i = 0; i < memrefType.getRank(); ++i)
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result;
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that infers the constant values from a list of \p values,
91 /// a \p memRefTy, and another helper function \p getAttributes.
92 /// The inferred constant values replace the related `OpFoldResult` in
93 /// \p values.
94 ///
95 /// \note This function shouldn't be used directly, instead, use the
96 /// `getConstifiedMixedXXX` methods from the related operations.
97 ///
98 /// \p getAttributes returns a list of potentially constant values, as determined
99 /// by \p isDynamic, from the given \p memRefTy. The returned list must have as
100 /// many elements as \p values or be empty.
101 ///
102 /// E.g., consider the following example:
103 /// ```
104 /// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
105 /// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
106 /// ```
107 /// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
108 /// Now using this helper function with:
109 /// - `values == [2, %dyn_stride]`,
110 /// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
111 /// - `getAttributes == getConstantStrides` (i.e., a wrapper around
112 /// `getStridesAndOffset`), and
113 /// - `isDynamic == ShapedType::isDynamic`
114 /// Will yield: `values == [2, 1]`
115 static void constifyIndexValues(
116  SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
117  MLIRContext *ctxt,
118  llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
119  llvm::function_ref<bool(int64_t)> isDynamic) {
120  SmallVector<int64_t> constValues = getAttributes(memRefTy);
121  Builder builder(ctxt);
122  for (const auto &it : llvm::enumerate(constValues)) {
123  int64_t constValue = it.value();
124  if (!isDynamic(constValue))
125  values[it.index()] = builder.getIndexAttr(constValue);
126  }
127  for (OpFoldResult &ofr : values) {
128  if (ofr.is<Attribute>()) {
129  // FIXME: We shouldn't need to do that, but right now, the static indices
130  // are created with the wrong type: `i64` instead of `index`.
131  // As a result, if we were to keep the attribute as is, we may fail to see
132  // that two attributes are equal because one would have the i64 type and
133  // the other the index type.
134  // The alternative would be to create constant indices with getI64Attr in
135  // this and the previous loop, but it doesn't logically make sense (we are
136  // dealing with indices here) and would only strenghten the inconsistency
137  // around how static indices are created (some places use getI64Attr,
138  // others use getIndexAttr).
139  // The workaround here is to stick to the IndexAttr type for all the
140  // values, hence we recreate the attribute even when it is already static
141  // to make sure the type is consistent.
142  ofr = builder.getIndexAttr(
143  llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
144  continue;
145  }
146  std::optional<int64_t> maybeConstant =
147  getConstantIntValue(ofr.get<Value>());
148  if (maybeConstant)
149  ofr = builder.getIndexAttr(*maybeConstant);
150  }
151 }
152 
153 /// Wrapper around `getShape` that conforms to the function signature
154 /// expected for `getAttributes` in `constifyIndexValues`.
155 static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
156  ArrayRef<int64_t> sizes = memRefTy.getShape();
157  return SmallVector<int64_t>(sizes);
158 }
159 
160 /// Wrapper around `getStridesAndOffset` that returns only the offset and
161 /// conforms to the function signature expected for `getAttributes` in
162 /// `constifyIndexValues`.
163 static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
164  SmallVector<int64_t> strides;
165  int64_t offset;
166  LogicalResult hasStaticInformation =
167  getStridesAndOffset(memrefType, strides, offset);
168  if (failed(hasStaticInformation))
169  return SmallVector<int64_t>();
170  return SmallVector<int64_t>(1, offset);
171 }
172 
173 /// Wrapper around `getStridesAndOffset` that returns only the strides and
174 /// conforms to the function signature expected for `getAttributes` in
175 /// `constifyIndexValues`.
176 static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
177  SmallVector<int64_t> strides;
178  int64_t offset;
179  LogicalResult hasStaticInformation =
180  getStridesAndOffset(memrefType, strides, offset);
181  if (failed(hasStaticInformation))
182  return SmallVector<int64_t>();
183  return strides;
184 }
185 
186 //===----------------------------------------------------------------------===//
187 // AllocOp / AllocaOp
188 //===----------------------------------------------------------------------===//
189 
190 void AllocOp::getAsmResultNames(
191  function_ref<void(Value, StringRef)> setNameFn) {
192  setNameFn(getResult(), "alloc");
193 }
194 
195 void AllocaOp::getAsmResultNames(
196  function_ref<void(Value, StringRef)> setNameFn) {
197  setNameFn(getResult(), "alloca");
198 }
199 
200 template <typename AllocLikeOp>
201 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
202  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
203  "applies to only alloc or alloca");
204  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
205  if (!memRefType)
206  return op.emitOpError("result must be a memref");
207 
208  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
209  return op.emitOpError("dimension operand count does not equal memref "
210  "dynamic dimension count");
211 
212  unsigned numSymbols = 0;
213  if (!memRefType.getLayout().isIdentity())
214  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
215  if (op.getSymbolOperands().size() != numSymbols)
216  return op.emitOpError("symbol operand count does not equal memref symbol "
217  "count: expected ")
218  << numSymbols << ", got " << op.getSymbolOperands().size();
219 
220  return success();
221 }
222 
223 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
224 
225 LogicalResult AllocaOp::verify() {
226  // An alloca op needs to have an ancestor with an allocation scope trait.
227  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
228  return emitOpError(
229  "requires an ancestor op with AutomaticAllocationScope trait");
230 
231  return verifyAllocLikeOp(*this);
232 }
233 
234 namespace {
235 /// Fold constant dimensions into an alloc like operation.
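/// A hedged sketch of the rewrite (the names %c16 and %n are made up):
/// ```mlir
/// %c16 = arith.constant 16 : index
/// %0 = memref.alloc(%c16, %n) : memref<?x?xf32>
/// // canonicalizes to (the result is cast back to the original type):
/// %1 = memref.alloc(%n) : memref<16x?xf32>
/// %2 = memref.cast %1 : memref<16x?xf32> to memref<?x?xf32>
/// ```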
236 template <typename AllocLikeOp>
237 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
238  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
239 
240  LogicalResult matchAndRewrite(AllocLikeOp alloc,
241  PatternRewriter &rewriter) const override {
242  // Check to see if any dimension operands are constants. If so, we can
243  // substitute and drop them.
244  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
245  APInt constSizeArg;
246  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
247  return false;
248  return constSizeArg.isNonNegative();
249  }))
250  return failure();
251 
252  auto memrefType = alloc.getType();
253 
254  // Ok, we have one or more constant operands. Collect the non-constant ones
255  // and keep track of the resultant memref type to build.
256  SmallVector<int64_t, 4> newShapeConstants;
257  newShapeConstants.reserve(memrefType.getRank());
258  SmallVector<Value, 4> dynamicSizes;
259 
260  unsigned dynamicDimPos = 0;
261  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
262  int64_t dimSize = memrefType.getDimSize(dim);
263  // If this is already a static dimension, keep it.
264  if (!ShapedType::isDynamic(dimSize)) {
265  newShapeConstants.push_back(dimSize);
266  continue;
267  }
268  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
269  APInt constSizeArg;
270  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
271  constSizeArg.isNonNegative()) {
272  // Dynamic shape dimension will be folded.
273  newShapeConstants.push_back(constSizeArg.getZExtValue());
274  } else {
275  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
276  newShapeConstants.push_back(ShapedType::kDynamic);
277  dynamicSizes.push_back(dynamicSize);
278  }
279  dynamicDimPos++;
280  }
281 
282  // Create new memref type (which will have fewer dynamic dimensions).
283  MemRefType newMemRefType =
284  MemRefType::Builder(memrefType).setShape(newShapeConstants);
285  assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
286 
287  // Create and insert the alloc op for the new memref.
288  auto newAlloc = rewriter.create<AllocLikeOp>(
289  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
290  alloc.getAlignmentAttr());
291  // Insert a cast so we have the same type as the old alloc.
292  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
293  return success();
294  }
295 };
296 
297 /// Fold alloc operations with no users or only store and dealloc uses.
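/// A minimal sketch (value names are made up) of IR this pattern deletes:
/// ```mlir
/// %0 = memref.alloc() : memref<4xf32>
/// memref.store %v, %0[%i] : memref<4xf32>
/// memref.dealloc %0 : memref<4xf32>
/// // all three ops are erased, since the stored value is never read back.
/// ```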
298 template <typename T>
299 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
300  using OpRewritePattern<T>::OpRewritePattern;
301 
302  LogicalResult matchAndRewrite(T alloc,
303  PatternRewriter &rewriter) const override {
304  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
305  if (auto storeOp = dyn_cast<StoreOp>(op))
306  return storeOp.getValue() == alloc;
307  return !isa<DeallocOp>(op);
308  }))
309  return failure();
310 
311  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
312  rewriter.eraseOp(user);
313 
314  rewriter.eraseOp(alloc);
315  return success();
316  }
317 };
318 } // namespace
319 
320 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
321  MLIRContext *context) {
322  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
323 }
324 
325 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
326  MLIRContext *context) {
327  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
328  context);
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // ReallocOp
333 //===----------------------------------------------------------------------===//
334 
335 LogicalResult ReallocOp::verify() {
336  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
337  MemRefType resultType = getType();
338 
339  // The source memref should have identity layout (or none).
340  if (!sourceType.getLayout().isIdentity())
341  return emitError("unsupported layout for source memref type ")
342  << sourceType;
343 
344  // The result memref should have identity layout (or none).
345  if (!resultType.getLayout().isIdentity())
346  return emitError("unsupported layout for result memref type ")
347  << resultType;
348 
349  // The source memref and the result memref should be in the same memory space.
350  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
351  return emitError("different memory spaces specified for source memref "
352  "type ")
353  << sourceType << " and result memref type " << resultType;
354 
355  // The source memref and the result memref should have the same element type.
356  if (sourceType.getElementType() != resultType.getElementType())
357  return emitError("different element types specified for source memref "
358  "type ")
359  << sourceType << " and result memref type " << resultType;
360 
361  // Verify that we have the dynamic dimension operand when it is needed.
362  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
363  return emitError("missing dimension operand for result type ")
364  << resultType;
365  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
366  return emitError("unnecessary dimension operand for result type ")
367  << resultType;
368 
369  return success();
370 }
371 
372 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
373  MLIRContext *context) {
374  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // AllocaScopeOp
379 //===----------------------------------------------------------------------===//
380 
381 void AllocaScopeOp::print(OpAsmPrinter &p) {
382  bool printBlockTerminators = false;
383 
384  p << ' ';
385  if (!getResults().empty()) {
386  p << " -> (" << getResultTypes() << ")";
387  printBlockTerminators = true;
388  }
389  p << ' ';
390  p.printRegion(getBodyRegion(),
391  /*printEntryBlockArgs=*/false,
392  /*printBlockTerminators=*/printBlockTerminators);
393  p.printOptionalAttrDict((*this)->getAttrs());
394 }
395 
396 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
397  // Create a region for the body.
398  result.regions.reserve(1);
399  Region *bodyRegion = result.addRegion();
400 
401  // Parse optional results type list.
402  if (parser.parseOptionalArrowTypeList(result.types))
403  return failure();
404 
405  // Parse the body region.
406  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
407  return failure();
408  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
409  result.location);
410 
411  // Parse the optional attribute list.
412  if (parser.parseOptionalAttrDict(result.attributes))
413  return failure();
414 
415  return success();
416 }
417 
418 void AllocaScopeOp::getSuccessorRegions(
419  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
420  if (!point.isParent()) {
421  regions.push_back(RegionSuccessor(getResults()));
422  return;
423  }
424 
425  regions.push_back(RegionSuccessor(&getBodyRegion()));
426 }
427 
428 /// Given an operation, return whether this op is guaranteed to
429 /// allocate an AutomaticAllocationScopeResource
430 static bool isGuaranteedAutomaticAllocation(Operation *op) {
431  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
432  if (!interface)
433  return false;
434  for (auto res : op->getResults()) {
435  if (auto effect =
436  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
437  if (isa<SideEffects::AutomaticAllocationScopeResource>(
438  effect->getResource()))
439  return true;
440  }
441  }
442  return false;
443 }
444 
445 /// Given an operation, return whether this op itself could
446 /// allocate an AutomaticAllocationScopeResource. Note that
447 /// this will not check whether an operation contained within
448 /// the op can allocate.
449 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
450  // This op itself doesn't create a stack allocation,
451  // the inner allocation should be handled separately.
452  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
453  return false;
454  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
455  if (!interface)
456  return true;
457  for (auto res : op->getResults()) {
458  if (auto effect =
459  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
460  if (isa<SideEffects::AutomaticAllocationScopeResource>(
461  effect->getResource()))
462  return true;
463  }
464  }
465  return false;
466 }
467 
468 /// Return whether this op is the last non-terminating op
469 /// in a region. That is to say, it is in a one-block region
470 /// and is only followed by a terminator. This prevents
471 /// extending the lifetime of allocations.
472 static bool lastNonTerminatorInRegion(Operation *op) {
473  return op->getNextNode() == op->getBlock()->getTerminator() &&
474  op->getParentRegion()->getBlocks().size() == 1;
475 }
476 
477 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
478 /// or it contains no allocation.
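/// For example, a scope with no allocation (a hypothetical snippet) may be
/// flattened into its parent:
/// ```mlir
/// %r = memref.alloca_scope -> (index) {
///   %c0 = arith.constant 0 : index
///   memref.alloca_scope.return %c0 : index
/// }
/// // may be inlined to:
/// %c0 = arith.constant 0 : index
/// ```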
479 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
480  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
481 
482  LogicalResult matchAndRewrite(AllocaScopeOp op,
483  PatternRewriter &rewriter) const override {
484  bool hasPotentialAlloca =
485  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
486  if (alloc == op)
487  return WalkResult::advance();
488  if (isOpItselfPotentialAutomaticAllocation(alloc))
489  return WalkResult::interrupt();
490  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
491  return WalkResult::skip();
492  return WalkResult::advance();
493  }).wasInterrupted();
494 
495  // If this contains no potential allocation, it is always legal to
496  // inline. Otherwise, consider two conditions:
497  if (hasPotentialAlloca) {
498  // If the parent isn't an allocation scope, or we are not the last
499  // non-terminator op in the parent, we will extend the lifetime.
500  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
501  return failure();
502  if (!lastNonTerminatorInRegion(op))
503  return failure();
504  }
505 
506  Block *block = &op.getRegion().front();
507  Operation *terminator = block->getTerminator();
508  ValueRange results = terminator->getOperands();
509  rewriter.inlineBlockBefore(block, op);
510  rewriter.replaceOp(op, results);
511  rewriter.eraseOp(terminator);
512  return success();
513  }
514 };
515 
516 /// Move allocations into an allocation scope, if it is legal to
517 /// move them (e.g. their operands are available at the location
518 /// the op would be moved to).
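/// As a rough, made-up sketch: an alloca defined inside the scope whose
/// operands are available outside of it may be hoisted, subject to the
/// position checks implemented below:
/// ```mlir
/// scf.for %i = %lb to %ub step %s {
///   memref.alloca_scope {
///     %a = memref.alloca() : memref<16xf32>
///     ...
///   }
/// }
/// ```
/// Here %a may be cloned right before the scf.for (which sits inside a
/// surrounding AutomaticAllocationScope such as a func.func), so the stack
/// allocation is not re-created on every iteration.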
519 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
520  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
521 
522  LogicalResult matchAndRewrite(AllocaScopeOp op,
523  PatternRewriter &rewriter) const override {
524 
525  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
526  return failure();
527 
528  Operation *lastParentWithoutScope = op->getParentOp();
529 
530  if (!lastParentWithoutScope ||
531  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
532  return failure();
533 
534  // Only apply if this is the last non-terminator
535  // op in the block (lest lifetime be extended) of a
536  // one-block region.
537  if (!lastNonTerminatorInRegion(op) ||
538  !lastNonTerminatorInRegion(lastParentWithoutScope))
539  return failure();
540 
541  while (!lastParentWithoutScope->getParentOp()
542  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
543  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
544  if (!lastParentWithoutScope ||
545  !lastNonTerminatorInRegion(lastParentWithoutScope))
546  return failure();
547  }
548  assert(lastParentWithoutScope->getParentOp()
549  ->hasTrait<OpTrait::AutomaticAllocationScope>());
550 
551  Region *containingRegion = nullptr;
552  for (auto &r : lastParentWithoutScope->getRegions()) {
553  if (r.isAncestor(op->getParentRegion())) {
554  assert(containingRegion == nullptr &&
555  "only one region can contain the op");
556  containingRegion = &r;
557  }
558  }
559  assert(containingRegion && "op must be contained in a region");
560 
561  SmallVector<Operation *> toHoist;
562  op->walk([&](Operation *alloc) {
563  if (!isGuaranteedAutomaticAllocation(alloc))
564  return WalkResult::skip();
565 
566  // If any operand is not defined before the location of
567  // lastParentWithoutScope (i.e. where we would hoist to), skip.
568  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
569  return containingRegion->isAncestor(v.getParentRegion());
570  }))
571  return WalkResult::skip();
572  toHoist.push_back(alloc);
573  return WalkResult::advance();
574  });
575 
576  if (toHoist.empty())
577  return failure();
578  rewriter.setInsertionPoint(lastParentWithoutScope);
579  for (auto *op : toHoist) {
580  auto *cloned = rewriter.clone(*op);
581  rewriter.replaceOp(op, cloned->getResults());
582  }
583  return success();
584  }
585 };
586 
587 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
588  MLIRContext *context) {
589  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // AssumeAlignmentOp
594 //===----------------------------------------------------------------------===//
595 
596 LogicalResult AssumeAlignmentOp::verify() {
597  if (!llvm::isPowerOf2_32(getAlignment()))
598  return emitOpError("alignment must be power of 2");
599  return success();
600 }
601 
602 //===----------------------------------------------------------------------===//
603 // CastOp
604 //===----------------------------------------------------------------------===//
605 
606 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
607  setNameFn(getResult(), "cast");
608 }
609 
610 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
611 /// source memref. This is useful to fold a memref.cast into a consuming op
612 /// and implement canonicalization patterns for ops in different dialects that
613 /// may consume the results of memref.cast operations. Such foldable memref.cast
614 /// operations are typically inserted as `view` and `subview` ops are
615 /// canonicalized, to preserve the type compatibility of their uses.
616 ///
617 /// Returns true when all conditions are met:
618 /// 1. source and result are ranked memrefs with strided semantics and same
619 /// element type and rank.
620 /// 2. each of the source's size, offset or stride has more static information
621 /// than the corresponding result's size, offset or stride.
622 ///
623 /// Example 1:
624 /// ```mlir
625 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
626 /// %2 = consumer %1 ... : memref<?x?xf32> ...
627 /// ```
628 ///
629 /// may fold into:
630 ///
631 /// ```mlir
632 /// %2 = consumer %0 ... : memref<8x16xf32> ...
633 /// ```
634 ///
635 /// Example 2:
636 /// ```
637 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
638 /// to memref<?x?xf32>
639 /// consumer %1 : memref<?x?xf32> ...
640 /// ```
641 ///
642 /// may fold into:
643 ///
644 /// ```
645 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
646 /// ```
647 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
648  MemRefType sourceType =
649  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
650  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
651 
652  // Requires ranked MemRefType.
653  if (!sourceType || !resultType)
654  return false;
655 
656  // Requires same elemental type.
657  if (sourceType.getElementType() != resultType.getElementType())
658  return false;
659 
660  // Requires same rank.
661  if (sourceType.getRank() != resultType.getRank())
662  return false;
663 
664  // Only fold casts between strided memref forms.
665  int64_t sourceOffset, resultOffset;
666  SmallVector<int64_t, 4> sourceStrides, resultStrides;
667  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
668  failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
669  return false;
670 
671  // If cast is towards more static sizes along any dimension, don't fold.
672  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
673  auto ss = std::get<0>(it), st = std::get<1>(it);
674  if (ss != st)
675  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
676  return false;
677  }
678 
679  // If cast is towards more static offset along any dimension, don't fold.
680  if (sourceOffset != resultOffset)
681  if (ShapedType::isDynamic(sourceOffset) &&
682  !ShapedType::isDynamic(resultOffset))
683  return false;
684 
685  // If cast is towards more static strides along any dimension, don't fold.
686  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
687  auto ss = std::get<0>(it), st = std::get<1>(it);
688  if (ss != st)
689  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
690  return false;
691  }
692 
693  return true;
694 }
695 
696 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
697  if (inputs.size() != 1 || outputs.size() != 1)
698  return false;
699  Type a = inputs.front(), b = outputs.front();
700  auto aT = llvm::dyn_cast<MemRefType>(a);
701  auto bT = llvm::dyn_cast<MemRefType>(b);
702 
703  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
704  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
705 
706  if (aT && bT) {
707  if (aT.getElementType() != bT.getElementType())
708  return false;
709  if (aT.getLayout() != bT.getLayout()) {
710  int64_t aOffset, bOffset;
711  SmallVector<int64_t, 4> aStrides, bStrides;
712  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
713  failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
714  aStrides.size() != bStrides.size())
715  return false;
716 
717  // Strides along a dimension/offset are compatible if the value in the
718  // source memref is static and the value in the target memref is the
719  // same. They are also compatible if either one is dynamic (see
720  // description of MemRefCastOp for details).
721  auto checkCompatible = [](int64_t a, int64_t b) {
722  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
723  };
724  if (!checkCompatible(aOffset, bOffset))
725  return false;
726  for (const auto &aStride : enumerate(aStrides))
727  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
728  return false;
729  }
730  if (aT.getMemorySpace() != bT.getMemorySpace())
731  return false;
732 
733  // They must have the same rank, and any specified dimensions must match.
734  if (aT.getRank() != bT.getRank())
735  return false;
736 
737  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
738  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
739  if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
740  aDim != bDim)
741  return false;
742  }
743  return true;
744  } else {
745  if (!aT && !uaT)
746  return false;
747  if (!bT && !ubT)
748  return false;
749  // Unranked to unranked casting is unsupported
750  if (uaT && ubT)
751  return false;
752 
753  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
754  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
755  if (aEltType != bEltType)
756  return false;
757 
758  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
759  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
760  return aMemSpace == bMemSpace;
761  }
762 
763  return false;
764 }
765 
766 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
767  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // CopyOp
772 //===----------------------------------------------------------------------===//
773 
774 namespace {
775 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
776 /// and element type, the cast can be skipped. Such CastOps only cast the layout
777 /// of the type.
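/// A small illustrative case (names are made up): a cast that only relaxes
/// the layout can be bypassed by the copy:
/// ```mlir
/// %0 = memref.cast %src : memref<4xf32> to memref<4xf32, strided<[1], offset: ?>>
/// memref.copy %0, %dst : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32>
/// // becomes
/// memref.copy %src, %dst : memref<4xf32> to memref<4xf32>
/// ```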
778 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
779  using OpRewritePattern<CopyOp>::OpRewritePattern;
780 
781  LogicalResult matchAndRewrite(CopyOp copyOp,
782  PatternRewriter &rewriter) const override {
783  bool modified = false;
784 
785  // Check source.
786  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
787  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
788  auto toType = llvm::dyn_cast<MemRefType>(copyOp.getSource().getType());
789 
790  if (fromType && toType) {
791  if (fromType.getShape() == toType.getShape() &&
792  fromType.getElementType() == toType.getElementType()) {
793  rewriter.modifyOpInPlace(copyOp, [&] {
794  copyOp.getSourceMutable().assign(castOp.getSource());
795  });
796  modified = true;
797  }
798  }
799  }
800 
801  // Check target.
802  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
803  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
804  auto toType = llvm::dyn_cast<MemRefType>(copyOp.getTarget().getType());
805 
806  if (fromType && toType) {
807  if (fromType.getShape() == toType.getShape() &&
808  fromType.getElementType() == toType.getElementType()) {
809  rewriter.modifyOpInPlace(copyOp, [&] {
810  copyOp.getTargetMutable().assign(castOp.getSource());
811  });
812  modified = true;
813  }
814  }
815  }
816 
817  return success(modified);
818  }
819 };
820 
821 /// Fold memref.copy(%x, %x).
822 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
823  using OpRewritePattern<CopyOp>::OpRewritePattern;
824 
825  LogicalResult matchAndRewrite(CopyOp copyOp,
826  PatternRewriter &rewriter) const override {
827  if (copyOp.getSource() != copyOp.getTarget())
828  return failure();
829 
830  rewriter.eraseOp(copyOp);
831  return success();
832  }
833 };
834 
835 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
836  using OpRewritePattern<CopyOp>::OpRewritePattern;
837 
838  static bool isEmptyMemRef(BaseMemRefType type) {
839  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
840  }
841 
842  LogicalResult matchAndRewrite(CopyOp copyOp,
843  PatternRewriter &rewriter) const override {
844  if (isEmptyMemRef(copyOp.getSource().getType()) ||
845  isEmptyMemRef(copyOp.getTarget().getType())) {
846  rewriter.eraseOp(copyOp);
847  return success();
848  }
849 
850  return failure();
851  }
852 };
853 } // namespace
854 
855 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
856  MLIRContext *context) {
857  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
858 }
859 
860 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
861  SmallVectorImpl<OpFoldResult> &results) {
862  /// copy(memrefcast) -> copy
863  bool folded = false;
864  Operation *op = *this;
865  for (OpOperand &operand : op->getOpOperands()) {
866  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
867  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
868  operand.set(castOp.getOperand());
869  folded = true;
870  }
871  }
872  return success(folded);
873 }
874 
875 //===----------------------------------------------------------------------===//
876 // DeallocOp
877 //===----------------------------------------------------------------------===//
878 
879 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
880  SmallVectorImpl<OpFoldResult> &results) {
881  /// dealloc(memrefcast) -> dealloc
882  return foldMemRefCast(*this);
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // DimOp
887 //===----------------------------------------------------------------------===//
888 
889 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
890  setNameFn(getResult(), "dim");
891 }
892 
893 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
894  int64_t index) {
895  auto loc = result.location;
896  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
897  build(builder, result, source, indexValue);
898 }
899 
900 std::optional<int64_t> DimOp::getConstantIndex() {
901  return getConstantIntValue(getIndex());
902 }
903 
904 Speculation::Speculatability DimOp::getSpeculatability() {
905  auto constantIndex = getConstantIndex();
906  if (!constantIndex)
907  return Speculation::NotSpeculatable;
908 
909  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
910  if (!rankedSourceType)
911  return Speculation::NotSpeculatable;
912 
913  if (rankedSourceType.getRank() <= constantIndex)
914  return Speculation::NotSpeculatable;
915 
916  return Speculation::Speculatable;
917 }
918 
919 /// Return a map with key being elements in `vals` and data being number of
920 /// occurrences of it. Use std::map, since the `vals` here are strides and the
921 /// dynamic stride value is the same as the tombstone value for
922 /// `DenseMap<int64_t>`.
923 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
924  std::map<int64_t, unsigned> numOccurences;
925  for (auto val : vals)
926  numOccurences[val]++;
927  return numOccurences;
928 }
929 
930 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
931 /// to be a subset of `originalType` with some `1` entries erased, return the
932 /// set of indices that specifies which of the entries of `originalShape` are
933 /// dropped to obtain `reducedShape`.
934 /// This accounts for cases where there are multiple unit-dims, but only a
935 /// subset of those are dropped. For MemRefTypes these can be disambiguated
936 /// using the strides. If a dimension is dropped the stride must be dropped too.
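/// As a rough, made-up example:
/// ```
/// originalType = memref<1x4x1x8xf32, strided<[32, 8, 8, 1]>>
/// reducedType  = memref<4x8xf32, strided<[8, 1]>>
/// sizes        = [1, 4, 1, 8]
/// result mask  = {0, 2}   // both unit dims (with strides 32 and 8) dropped
/// ```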
937 static FailureOr<llvm::SmallBitVector>
938 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
939  ArrayRef<OpFoldResult> sizes) {
940  llvm::SmallBitVector unusedDims(originalType.getRank());
941  if (originalType.getRank() == reducedType.getRank())
942  return unusedDims;
943 
944  for (const auto &dim : llvm::enumerate(sizes))
945  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
946  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
947  unusedDims.set(dim.index());
948 
949  // Early exit for the case where the number of unused dims matches the number
950  // of ranks reduced.
951  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
952  originalType.getRank())
953  return unusedDims;
954 
955  SmallVector<int64_t> originalStrides, candidateStrides;
956  int64_t originalOffset, candidateOffset;
957  if (failed(
958  getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
959  failed(
960  getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
961  return failure();
962 
963  // For memrefs, a dimension is truly dropped if its corresponding stride is
964  // also dropped. This is particularly important when more than one of the dims
965  // is 1. Track the number of occurrences of the strides in the original type
966  // and the candidate type. For each unused dim that stride should not be
967  // present in the candidate type. Note that there could be multiple dimensions
968  // that have the same size. We don't need to exactly figure out which dim
969  // corresponds to which stride, we just need to verify that the number of
970  // repetitions of a stride in the original + number of unused dims with that
971  // stride == number of repetitions of a stride in the candidate.
972  std::map<int64_t, unsigned> currUnaccountedStrides =
973  getNumOccurences(originalStrides);
974  std::map<int64_t, unsigned> candidateStridesNumOccurences =
975  getNumOccurences(candidateStrides);
976  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
977  if (!unusedDims.test(dim))
978  continue;
979  int64_t originalStride = originalStrides[dim];
980  if (currUnaccountedStrides[originalStride] >
981  candidateStridesNumOccurences[originalStride]) {
982  // This dim can be treated as dropped.
983  currUnaccountedStrides[originalStride]--;
984  continue;
985  }
986  if (currUnaccountedStrides[originalStride] ==
987  candidateStridesNumOccurences[originalStride]) {
988  // The stride for this is not dropped. Keep as is.
989  unusedDims.reset(dim);
990  continue;
991  }
992  if (currUnaccountedStrides[originalStride] <
993  candidateStridesNumOccurences[originalStride]) {
994  // This should never happen. We can't have a stride in the reduced-rank type
995  // that wasn't in the original one.
996  return failure();
997  }
998  }
999 
1000  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1001  originalType.getRank())
1002  return failure();
1003  return unusedDims;
1004 }
1005 
1006 llvm::SmallBitVector SubViewOp::getDroppedDims() {
1007  MemRefType sourceType = getSourceType();
1008  MemRefType resultType = getType();
1009  FailureOr<llvm::SmallBitVector> unusedDims =
1010  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1011  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1012  return *unusedDims;
1013 }
1014 
1015 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1016  // All forms of folding require a known index.
1017  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1018  if (!index)
1019  return {};
1020 
1021  // Folding for unranked types (UnrankedMemRefType) is not supported.
1022  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1023  if (!memrefType)
1024  return {};
1025 
1026  // Out of bound indices produce undefined behavior but are still valid IR.
1027  // Don't choke on them.
1028  int64_t indexVal = index.getInt();
1029  if (indexVal < 0 || indexVal >= memrefType.getRank())
1030  return {};
1031 
1032  // Fold if the shape extent along the given index is known.
1033  if (!memrefType.isDynamicDim(index.getInt())) {
1034  Builder builder(getContext());
1035  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1036  }
1037 
1038  // The size at the given index is now known to be a dynamic size.
1039  unsigned unsignedIndex = index.getValue().getZExtValue();
1040 
1041  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1042  Operation *definingOp = getSource().getDefiningOp();
1043 
1044  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1045  return *(alloc.getDynamicSizes().begin() +
1046  memrefType.getDynamicDimIndex(unsignedIndex));
1047 
1048  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1049  return *(alloca.getDynamicSizes().begin() +
1050  memrefType.getDynamicDimIndex(unsignedIndex));
1051 
1052  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1053  return *(view.getDynamicSizes().begin() +
1054  memrefType.getDynamicDimIndex(unsignedIndex));
1055 
1056  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1057  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1058  unsigned resultIndex = 0;
1059  unsigned sourceRank = subview.getSourceType().getRank();
1060  unsigned sourceIndex = 0;
1061  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1062  if (unusedDims.test(i))
1063  continue;
1064  if (resultIndex == unsignedIndex) {
1065  sourceIndex = i;
1066  break;
1067  }
1068  resultIndex++;
1069  }
1070  assert(subview.isDynamicSize(sourceIndex) &&
1071  "expected dynamic subview size");
1072  return subview.getDynamicSize(sourceIndex);
1073  }
1074 
1075  if (auto sizeInterface =
1076  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1077  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1078  "Expected dynamic subview size");
1079  return sizeInterface.getDynamicSize(unsignedIndex);
1080  }
1081 
1082  // dim(memrefcast) -> dim
1083  if (succeeded(foldMemRefCast(*this)))
1084  return getResult();
1085 
1086  return {};
1087 }
1088 
1089 namespace {
1090 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1091 /// operand.
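/// A hedged sketch of the fold (value names are invented):
/// ```mlir
/// %r = memref.reshape %src(%shape)
///     : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
/// %d = memref.dim %r, %c1 : memref<?x?xf32>
/// // may become
/// %d = memref.load %shape[%c1] : memref<2xindex>
/// ```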
1092 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1093  using OpRewritePattern<DimOp>::OpRewritePattern;
1094 
1095  LogicalResult matchAndRewrite(DimOp dim,
1096  PatternRewriter &rewriter) const override {
1097  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1098 
1099  if (!reshape)
1100  return rewriter.notifyMatchFailure(
1101  dim, "Dim op is not defined by a reshape op.");
1102 
1103  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1104  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1105  // cheaply check that either of the following conditions hold:
1106  // 1. dim.getIndex() is defined in the same block as reshape but before
1107  // reshape.
1108  // 2. dim.getIndex() is defined in a parent block of
1109  // reshape.
1110 
1111  // Check condition 1
1112  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1113  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1114  if (reshape->isBeforeInBlock(definingOp)) {
1115  return rewriter.notifyMatchFailure(
1116  dim,
1117  "dim.getIndex is not defined before reshape in the same block.");
1118  }
1119  } // else dim.getIndex is a block argument to reshape->getBlock and
1120  // dominates reshape
1121  } // Check condition 2
1122  else if (dim->getBlock() != reshape->getBlock() &&
1123  !dim.getIndex().getParentRegion()->isProperAncestor(
1124  reshape->getParentRegion())) {
1125  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1126  // already know dim.getIndex() dominates reshape without calling
1127  // `isProperAncestor`
1128  return rewriter.notifyMatchFailure(
1129  dim, "dim.getIndex does not dominate reshape.");
1130  }
1131 
1132  // Place the load directly after the reshape to ensure that the shape memref
1133  // was not mutated.
1134  rewriter.setInsertionPointAfter(reshape);
1135  Location loc = dim.getLoc();
1136  Value load =
1137  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1138  if (load.getType() != dim.getType())
1139  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1140  rewriter.replaceOp(dim, load);
1141  return success();
1142  }
1143 };
1144 
1145 } // namespace
1146 
1147 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1148  MLIRContext *context) {
1149  results.add<DimOfMemRefReshape>(context);
1150 }
1151 
1152 // ---------------------------------------------------------------------------
1153 // DmaStartOp
1154 // ---------------------------------------------------------------------------
1155 
1156 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1157  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1158  ValueRange destIndices, Value numElements,
1159  Value tagMemRef, ValueRange tagIndices, Value stride,
1160  Value elementsPerStride) {
1161  result.addOperands(srcMemRef);
1162  result.addOperands(srcIndices);
1163  result.addOperands(destMemRef);
1164  result.addOperands(destIndices);
1165  result.addOperands({numElements, tagMemRef});
1166  result.addOperands(tagIndices);
1167  if (stride)
1168  result.addOperands({stride, elementsPerStride});
1169 }
1170 
1171 void DmaStartOp::print(OpAsmPrinter &p) {
1172  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1173  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1174  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1175  if (isStrided())
1176  p << ", " << getStride() << ", " << getNumElementsPerStride();
1177 
1178  p.printOptionalAttrDict((*this)->getAttrs());
1179  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1180  << ", " << getTagMemRef().getType();
1181 }
1182 
1183 // Parse DmaStartOp.
1184 // Ex:
1185 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1186 // %tag[%index], %stride, %num_elt_per_stride :
1187 // : memref<3076 x f32, 0>,
1188 // memref<1024 x f32, 2>,
1189 // memref<1 x i32>
1190 //
1191 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1192  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1193  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1194  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1195  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1196  OpAsmParser::UnresolvedOperand numElementsInfo;
1197  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1198  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1199  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1200 
1201  SmallVector<Type, 3> types;
1202  auto indexType = parser.getBuilder().getIndexType();
1203 
1204  // Parse and resolve the following list of operands:
1205  // *) source memref followed by its indices (in square brackets).
1206  // *) destination memref followed by its indices (in square brackets).
1207  // *) number of elements to transfer.
1208  if (parser.parseOperand(srcMemRefInfo) ||
1209  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1210  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1211  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1212  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1213  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1214  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1215  return failure();
1216 
1217  // Parse optional stride and elements per stride.
1218  if (parser.parseTrailingOperandList(strideInfo))
1219  return failure();
1220 
1221  bool isStrided = strideInfo.size() == 2;
1222  if (!strideInfo.empty() && !isStrided) {
1223  return parser.emitError(parser.getNameLoc(),
1224  "expected two stride related operands");
1225  }
1226 
1227  if (parser.parseColonTypeList(types))
1228  return failure();
1229  if (types.size() != 3)
1230  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1231 
1232  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1233  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1234  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1235  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1236  // size should be an index.
1237  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1238  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1239  // tag indices should be index.
1240  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1241  return failure();
1242 
1243  if (isStrided) {
1244  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1245  return failure();
1246  }
1247 
1248  return success();
1249 }
1250 
1251 LogicalResult DmaStartOp::verify() {
1252  unsigned numOperands = getNumOperands();
1253 
1254  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1255  // the number of elements.
1256  if (numOperands < 4)
1257  return emitOpError("expected at least 4 operands");
1258 
1259  // Check types of operands. The order of these calls is important: the later
1260  // calls rely on some type properties to compute the operand position.
1261  // 1. Source memref.
1262  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1263  return emitOpError("expected source to be of memref type");
1264  if (numOperands < getSrcMemRefRank() + 4)
1265  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1266  << " operands";
1267  if (!getSrcIndices().empty() &&
1268  !llvm::all_of(getSrcIndices().getTypes(),
1269  [](Type t) { return t.isIndex(); }))
1270  return emitOpError("expected source indices to be of index type");
1271 
1272  // 2. Destination memref.
1273  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1274  return emitOpError("expected destination to be of memref type");
1275  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1276  if (numOperands < numExpectedOperands)
1277  return emitOpError() << "expected at least " << numExpectedOperands
1278  << " operands";
1279  if (!getDstIndices().empty() &&
1280  !llvm::all_of(getDstIndices().getTypes(),
1281  [](Type t) { return t.isIndex(); }))
1282  return emitOpError("expected destination indices to be of index type");
1283 
1284  // 3. Number of elements.
1285  if (!getNumElements().getType().isIndex())
1286  return emitOpError("expected num elements to be of index type");
1287 
1288  // 4. Tag memref.
1289  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1290  return emitOpError("expected tag to be of memref type");
1291  numExpectedOperands += getTagMemRefRank();
1292  if (numOperands < numExpectedOperands)
1293  return emitOpError() << "expected at least " << numExpectedOperands
1294  << " operands";
1295  if (!getTagIndices().empty() &&
1296  !llvm::all_of(getTagIndices().getTypes(),
1297  [](Type t) { return t.isIndex(); }))
1298  return emitOpError("expected tag indices to be of index type");
1299 
1300  // Optional stride-related operands must be either both present or both
1301  // absent.
1302  if (numOperands != numExpectedOperands &&
1303  numOperands != numExpectedOperands + 2)
1304  return emitOpError("incorrect number of operands");
1305 
1306  // 5. Strides.
1307  if (isStrided()) {
1308  if (!getStride().getType().isIndex() ||
1309  !getNumElementsPerStride().getType().isIndex())
1310  return emitOpError(
1311  "expected stride and num elements per stride to be of type index");
1312  }
1313 
1314  return success();
1315 }
1316 
1317 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1318  SmallVectorImpl<OpFoldResult> &results) {
1319  /// dma_start(memrefcast) -> dma_start
1320  return foldMemRefCast(*this);
1321 }
1322 
1323 // ---------------------------------------------------------------------------
1324 // DmaWaitOp
1325 // ---------------------------------------------------------------------------
1326 
1327 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1328  SmallVectorImpl<OpFoldResult> &results) {
1329  /// dma_wait(memrefcast) -> dma_wait
1330  return foldMemRefCast(*this);
1331 }
1332 
1333 LogicalResult DmaWaitOp::verify() {
1334  // Check that the number of tag indices matches the tagMemRef rank.
1335  unsigned numTagIndices = getTagIndices().size();
1336  unsigned tagMemRefRank = getTagMemRefRank();
1337  if (numTagIndices != tagMemRefRank)
1338  return emitOpError() << "expected tagIndices to have the same number of "
1339  "elements as the tagMemRef rank, expected "
1340  << tagMemRefRank << ", but got " << numTagIndices;
1341  return success();
1342 }
1343 
1344 //===----------------------------------------------------------------------===//
1345 // ExtractAlignedPointerAsIndexOp
1346 //===----------------------------------------------------------------------===//
1347 
1348 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1349  function_ref<void(Value, StringRef)> setNameFn) {
1350  setNameFn(getResult(), "intptr");
1351 }
1352 
1353 //===----------------------------------------------------------------------===//
1354 // ExtractStridedMetadataOp
1355 //===----------------------------------------------------------------------===//
1356 
1357 /// The number and type of the results are inferred from the
1358 /// shape of the source.
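/// For example (a sketch; SSA names are made up), a rank-2 source yields one
/// base buffer, one offset, two sizes, and two strides:
/// ```mlir
/// %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
///     : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
/// ```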
1359 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1360  MLIRContext *context, std::optional<Location> location,
1361  ExtractStridedMetadataOp::Adaptor adaptor,
1362  SmallVectorImpl<Type> &inferredReturnTypes) {
1363  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1364  if (!sourceType)
1365  return failure();
1366 
1367  unsigned sourceRank = sourceType.getRank();
1368  IndexType indexType = IndexType::get(context);
1369  auto memrefType =
1370  MemRefType::get({}, sourceType.getElementType(),
1371  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1372  // Base.
1373  inferredReturnTypes.push_back(memrefType);
1374  // Offset.
1375  inferredReturnTypes.push_back(indexType);
1376  // Sizes and strides.
1377  for (unsigned i = 0; i < sourceRank * 2; ++i)
1378  inferredReturnTypes.push_back(indexType);
1379  return success();
1380 }
1381 
1382 void ExtractStridedMetadataOp::getAsmResultNames(
1383  function_ref<void(Value, StringRef)> setNameFn) {
1384  setNameFn(getBaseBuffer(), "base_buffer");
1385  setNameFn(getOffset(), "offset");
1386  // For multi-result to work properly with pretty names and packed syntax `x:3`
1387  // we can only give a pretty name to the first value in the pack.
1388  if (!getSizes().empty()) {
1389  setNameFn(getSizes().front(), "sizes");
1390  setNameFn(getStrides().front(), "strides");
1391  }
1392 }
1393 
1394 /// Helper function to perform the replacement of all constant uses of `values`
1395 /// by a materialized constant extracted from `maybeConstants`.
1396 /// `values` and `maybeConstants` are expected to have the same size.
1397 template <typename Container>
1398 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1399  Container values,
1400  ArrayRef<OpFoldResult> maybeConstants) {
1401  assert(values.size() == maybeConstants.size() &&
1402  " expected values and maybeConstants of the same size");
1403  bool atLeastOneReplacement = false;
1404  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1405  // Don't materialize a constant if there are no uses: this would induce
1406  // infinite loops in the driver.
1407  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1408  continue;
1409  assert(maybeConstant.template is<Attribute>() &&
1410  "The constified value should be either unchanged (i.e., == result) "
1411  "or a constant");
1412  Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1413  loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
1414  .getInt());
1415  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1416  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1417  // yet.
1418  op->replaceUsesOfWith(result, constantVal);
1419  atLeastOneReplacement = true;
1420  }
1421  }
1422  return atLeastOneReplacement;
1423 }
1424 
1425 LogicalResult
1426 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1427  SmallVectorImpl<OpFoldResult> &results) {
1428  OpBuilder builder(*this);
1429 
1430  bool atLeastOneReplacement = replaceConstantUsesOf(
1431  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1432  getConstifiedMixedOffset());
1433  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1434  getConstifiedMixedSizes());
1435  atLeastOneReplacement |= replaceConstantUsesOf(
1436  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1437 
1438  return success(atLeastOneReplacement);
1439 }
1440 
1441 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1442  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1443  constifyIndexValues(values, getSource().getType(), getContext(),
1444  getConstantSizes, ShapedType::isDynamic);
1445  return values;
1446 }
1447 
1448 SmallVector<OpFoldResult>
1449 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1450  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1451  constifyIndexValues(values, getSource().getType(), getContext(),
1452  getConstantStrides, ShapedType::isDynamic);
1453  return values;
1454 }
1455 
1456 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1457  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1458  SmallVector<OpFoldResult> values(1, offsetOfr);
1459  constifyIndexValues(values, getSource().getType(), getContext(),
1460  getConstantOffset, ShapedType::isDynamic);
1461  return values[0];
1462 }
1463 
1464 //===----------------------------------------------------------------------===//
1465 // GenericAtomicRMWOp
1466 //===----------------------------------------------------------------------===//
1467 
1468 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1469  Value memref, ValueRange ivs) {
1470  OpBuilder::InsertionGuard g(builder);
1471  result.addOperands(memref);
1472  result.addOperands(ivs);
1473 
1474  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1475  Type elementType = memrefType.getElementType();
1476  result.addTypes(elementType);
1477 
1478  Region *bodyRegion = result.addRegion();
1479  builder.createBlock(bodyRegion);
1480  bodyRegion->addArgument(elementType, memref.getLoc());
1481  }
1482 }
1483 
1484 LogicalResult GenericAtomicRMWOp::verify() {
1485  auto &body = getRegion();
1486  if (body.getNumArguments() != 1)
1487  return emitOpError("expected single number of entry block arguments");
1488 
1489  if (getResult().getType() != body.getArgument(0).getType())
1490  return emitOpError("expected block argument of the same type result type");
1491 
1492  bool hasSideEffects =
1493  body.walk([&](Operation *nestedOp) {
1494  if (isMemoryEffectFree(nestedOp))
1495  return WalkResult::advance();
1496  nestedOp->emitError(
1497  "body of 'memref.generic_atomic_rmw' should contain "
1498  "only operations with no side effects");
1499  return WalkResult::interrupt();
1500  })
1501  .wasInterrupted();
1502  return hasSideEffects ? failure() : success();
1503 }
1504 
1505 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1506  OperationState &result) {
1507  OpAsmParser::UnresolvedOperand memref;
1508  Type memrefType;
1509  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1510 
1511  Type indexType = parser.getBuilder().getIndexType();
1512  if (parser.parseOperand(memref) ||
1513  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1514  parser.parseColonType(memrefType) ||
1515  parser.resolveOperand(memref, memrefType, result.operands) ||
1516  parser.resolveOperands(ivs, indexType, result.operands))
1517  return failure();
1518 
1519  Region *body = result.addRegion();
1520  if (parser.parseRegion(*body, {}) ||
1521  parser.parseOptionalAttrDict(result.attributes))
1522  return failure();
1523  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1524  return success();
1525 }
1526 
1527 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1528  p << ' ' << getMemref() << "[" << getIndices()
1529  << "] : " << getMemref().getType() << ' ';
1530  p.printRegion(getRegion());
1531  p.printOptionalAttrDict((*this)->getAttrs());
1532 }
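// A sketch of the textual form the parser and printer above handle; the SSA
// names, element type, and body are illustrative:
//
//   %x = memref.generic_atomic_rmw %buf[%i] : memref<10xf32> {
//   ^bb0(%current : f32):
//     %cst = arith.constant 1.0 : f32
//     %new = arith.addf %current, %cst : f32
//     memref.atomic_yield %new : f32
//   }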
1533 
1534 //===----------------------------------------------------------------------===//
1535 // AtomicYieldOp
1536 //===----------------------------------------------------------------------===//
1537 
1538 LogicalResult AtomicYieldOp::verify() {
1539  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1540  Type resultType = getResult().getType();
1541  if (parentType != resultType)
1542  return emitOpError() << "types mismatch between yield op: " << resultType
1543  << " and its parent: " << parentType;
1544  return success();
1545 }
1546 
1547 //===----------------------------------------------------------------------===//
1548 // GlobalOp
1549 //===----------------------------------------------------------------------===//
1550 
1551 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1552  TypeAttr type,
1553  Attribute initialValue) {
1554  p << type;
1555  if (!op.isExternal()) {
1556  p << " = ";
1557  if (op.isUninitialized())
1558  p << "uninitialized";
1559  else
1560  p.printAttributeWithoutType(initialValue);
1561  }
1562 }
1563 
1564 static ParseResult
1565 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1566  Attribute &initialValue) {
1567  Type type;
1568  if (parser.parseType(type))
1569  return failure();
1570 
1571  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1572  if (!memrefType || !memrefType.hasStaticShape())
1573  return parser.emitError(parser.getNameLoc())
1574  << "type should be static shaped memref, but got " << type;
1575  typeAttr = TypeAttr::get(type);
1576 
1577  if (parser.parseOptionalEqual())
1578  return success();
1579 
1580  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1581  initialValue = UnitAttr::get(parser.getContext());
1582  return success();
1583  }
1584 
1585  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1586  if (parser.parseAttribute(initialValue, tensorType))
1587  return failure();
1588  if (!llvm::isa<ElementsAttr>(initialValue))
1589  return parser.emitError(parser.getNameLoc())
1590  << "initial value should be a unit or elements attribute";
1591  return success();
1592 }
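// A sketch of the type/initial-value tails accepted above; the symbol names
// and types are illustrative:
//
//   memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @u : memref<4xi32> = uninitialized
//   memref.global "private" @ext : memref<8xf32>  // external, no initializer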
1593 
1594 LogicalResult GlobalOp::verify() {
1595  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1596  if (!memrefType || !memrefType.hasStaticShape())
1597  return emitOpError("type should be static shaped memref, but got ")
1598  << getType();
1599 
1600  // Verify that the initial value, if present, is either a unit attribute or
1601  // an elements attribute.
1602  if (getInitialValue().has_value()) {
1603  Attribute initValue = getInitialValue().value();
1604  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1605  return emitOpError("initial value should be a unit or elements "
1606  "attribute, but got ")
1607  << initValue;
1608 
1609  // Check that the type of the initial value is compatible with the type of
1610  // the global variable.
1611  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1612  Type initType = elementsAttr.getType();
1613  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1614  if (initType != tensorType)
1615  return emitOpError("initial value expected to be of type ")
1616  << tensorType << ", but was of type " << initType;
1617  }
1618  }
1619 
1620  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1621  uint64_t alignment = *alignAttr;
1622 
1623  if (!llvm::isPowerOf2_64(alignment))
1624  return emitError() << "alignment attribute value " << alignment
1625  << " is not a power of 2";
1626  }
1627 
1628  // TODO: verify visibility for declarations.
1629  return success();
1630 }
1631 
1632 ElementsAttr GlobalOp::getConstantInitValue() {
1633  auto initVal = getInitialValue();
1634  if (getConstant() && initVal.has_value())
1635  return llvm::cast<ElementsAttr>(initVal.value());
1636  return {};
1637 }
1638 
1639 //===----------------------------------------------------------------------===//
1640 // GetGlobalOp
1641 //===----------------------------------------------------------------------===//
1642 
1643 LogicalResult
1644 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1645  // Verify that the result type is same as the type of the referenced
1646  // memref.global op.
1647  auto global =
1648  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1649  if (!global)
1650  return emitOpError("'")
1651  << getName() << "' does not reference a valid global memref";
1652 
1653  Type resultType = getResult().getType();
1654  if (global.getType() != resultType)
1655  return emitOpError("result type ")
1656  << resultType << " does not match type " << global.getType()
1657  << " of the global memref @" << getName();
1658  return success();
1659 }
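// A sketch of a use that passes the symbol check above, assuming a global @c
// of type memref<2xf32> is in scope:
//
//   %0 = memref.get_global @c : memref<2xf32>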
1660 
1661 //===----------------------------------------------------------------------===//
1662 // LoadOp
1663 //===----------------------------------------------------------------------===//
1664 
1665 LogicalResult LoadOp::verify() {
1666  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1667  return emitOpError("incorrect number of indices for load, expected ")
1668  << getMemRefType().getRank() << " but got " << getIndices().size();
1669  }
1670  return success();
1671 }
1672 
1673 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1674  /// load(memrefcast) -> load
1675  if (succeeded(foldMemRefCast(*this)))
1676  return getResult();
1677  return OpFoldResult();
1678 }
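// A sketch of the load(memref.cast) fold; names and types are illustrative:
//
//   %c = memref.cast %m : memref<4xf32> to memref<?xf32>
//   %v = memref.load %c[%i] : memref<?xf32>
//
// becomes
//
//   %v = memref.load %m[%i] : memref<4xf32>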
1679 
1680 //===----------------------------------------------------------------------===//
1681 // MemorySpaceCastOp
1682 //===----------------------------------------------------------------------===//
1683 
1684 void MemorySpaceCastOp::getAsmResultNames(
1685  function_ref<void(Value, StringRef)> setNameFn) {
1686  setNameFn(getResult(), "memspacecast");
1687 }
1688 
1689 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1690  if (inputs.size() != 1 || outputs.size() != 1)
1691  return false;
1692  Type a = inputs.front(), b = outputs.front();
1693  auto aT = llvm::dyn_cast<MemRefType>(a);
1694  auto bT = llvm::dyn_cast<MemRefType>(b);
1695 
1696  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1697  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1698 
1699  if (aT && bT) {
1700  if (aT.getElementType() != bT.getElementType())
1701  return false;
1702  if (aT.getLayout() != bT.getLayout())
1703  return false;
1704  if (aT.getShape() != bT.getShape())
1705  return false;
1706  return true;
1707  }
1708  if (uaT && ubT) {
1709  return uaT.getElementType() == ubT.getElementType();
1710  }
1711  return false;
1712 }
1713 
1714 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1715  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1716  // t2)
1717  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1718  getSourceMutable().assign(parentCast.getSource());
1719  return getResult();
1720  }
1721  return Value{};
1722 }
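// A sketch of the cast-of-cast fold; the memory spaces are illustrative:
//
//   %a = memref.memory_space_cast %m : memref<4xf32, 1> to memref<4xf32>
//   %b = memref.memory_space_cast %a : memref<4xf32> to memref<4xf32, 3>
//
// folds %b into a single cast of %m from memref<4xf32, 1> to memref<4xf32, 3>.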
1723 
1724 //===----------------------------------------------------------------------===//
1725 // PrefetchOp
1726 //===----------------------------------------------------------------------===//
1727 
1728 void PrefetchOp::print(OpAsmPrinter &p) {
1729  p << " " << getMemref() << '[';
1730  p.printOperands(getIndices());
1731  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1732  p << ", locality<" << getLocalityHint();
1733  p << ">, " << (getIsDataCache() ? "data" : "instr");
1734  p.printOptionalAttrDict(
1735  (*this)->getAttrs(),
1736  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1737  p << " : " << getMemRefType();
1738 }
1739 
1740 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1741  OpAsmParser::UnresolvedOperand memrefInfo;
1742  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1743  IntegerAttr localityHint;
1744  MemRefType type;
1745  StringRef readOrWrite, cacheType;
1746 
1747  auto indexTy = parser.getBuilder().getIndexType();
1748  auto i32Type = parser.getBuilder().getIntegerType(32);
1749  if (parser.parseOperand(memrefInfo) ||
1750  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1751  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1752  parser.parseComma() || parser.parseKeyword("locality") ||
1753  parser.parseLess() ||
1754  parser.parseAttribute(localityHint, i32Type, "localityHint",
1755  result.attributes) ||
1756  parser.parseGreater() || parser.parseComma() ||
1757  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1758  parser.resolveOperand(memrefInfo, type, result.operands) ||
1759  parser.resolveOperands(indexInfo, indexTy, result.operands))
1760  return failure();
1761 
1762  if (readOrWrite != "read" && readOrWrite != "write")
1763  return parser.emitError(parser.getNameLoc(),
1764  "rw specifier has to be 'read' or 'write'");
1765  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1766  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1767 
1768  if (cacheType != "data" && cacheType != "instr")
1769  return parser.emitError(parser.getNameLoc(),
1770  "cache type has to be 'data' or 'instr'");
1771 
1772  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1773  parser.getBuilder().getBoolAttr(cacheType == "data"));
1774 
1775  return success();
1776 }
1777 
1778 LogicalResult PrefetchOp::verify() {
1779  if (getNumOperands() != 1 + getMemRefType().getRank())
1780  return emitOpError("too few indices");
1781 
1782  return success();
1783 }
1784 
1785 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1786  SmallVectorImpl<OpFoldResult> &results) {
1787  // prefetch(memrefcast) -> prefetch
1788  return foldMemRefCast(*this);
1789 }
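// A sketch of the assembly handled by the parser/printer above; the operands
// and memref type are illustrative:
//
//   memref.prefetch %buf[%i, %j], read, locality<3>, data : memref<400x400xf32>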
1790 
1791 //===----------------------------------------------------------------------===//
1792 // RankOp
1793 //===----------------------------------------------------------------------===//
1794 
1795 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1796  // Constant fold rank when the rank of the operand is known.
1797  auto type = getOperand().getType();
1798  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1799  if (shapedType && shapedType.hasRank())
1800  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1801  return IntegerAttr();
1802 }
1803 
1804 //===----------------------------------------------------------------------===//
1805 // ReinterpretCastOp
1806 //===----------------------------------------------------------------------===//
1807 
1808 void ReinterpretCastOp::getAsmResultNames(
1809  function_ref<void(Value, StringRef)> setNameFn) {
1810  setNameFn(getResult(), "reinterpret_cast");
1811 }
1812 
1813 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1814 /// `staticSizes` and `staticStrides` are automatically filled with
1815 /// source-memref-rank sentinel values that encode dynamic entries.
1816 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1817  MemRefType resultType, Value source,
1818  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1819  ArrayRef<OpFoldResult> strides,
1820  ArrayRef<NamedAttribute> attrs) {
1821  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1822  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1823  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1824  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1825  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1826  result.addAttributes(attrs);
1827  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1828  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1829  b.getDenseI64ArrayAttr(staticSizes),
1830  b.getDenseI64ArrayAttr(staticStrides));
1831 }
1832 
1833 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1834  Value source, OpFoldResult offset,
1835  ArrayRef<OpFoldResult> sizes,
1836  ArrayRef<OpFoldResult> strides,
1837  ArrayRef<NamedAttribute> attrs) {
1838  auto sourceType = cast<BaseMemRefType>(source.getType());
1839  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1840  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1841  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1842  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1843  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1844  auto stridedLayout = StridedLayoutAttr::get(
1845  b.getContext(), staticOffsets.front(), staticStrides);
1846  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1847  stridedLayout, sourceType.getMemorySpace());
1848  build(b, result, resultType, source, offset, sizes, strides, attrs);
1849 }
1850 
1851 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1852  MemRefType resultType, Value source,
1853  int64_t offset, ArrayRef<int64_t> sizes,
1854  ArrayRef<int64_t> strides,
1855  ArrayRef<NamedAttribute> attrs) {
1856  SmallVector<OpFoldResult> sizeValues =
1857  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1858  return b.getI64IntegerAttr(v);
1859  }));
1860  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1861  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1862  return b.getI64IntegerAttr(v);
1863  }));
1864  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1865  strideValues, attrs);
1866 }
1867 
1868 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1869  MemRefType resultType, Value source, Value offset,
1870  ValueRange sizes, ValueRange strides,
1871  ArrayRef<NamedAttribute> attrs) {
1872  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1873  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1874  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1875  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1876  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1877 }
1878 
1879 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1880 // completed automatically, like we have for subview and extract_slice.
1881 LogicalResult ReinterpretCastOp::verify() {
1882  // The source and result memrefs should be in the same memory space.
1883  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1884  auto resultType = llvm::cast<MemRefType>(getType());
1885  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1886  return emitError("different memory spaces specified for source type ")
1887  << srcType << " and result memref type " << resultType;
1888  if (srcType.getElementType() != resultType.getElementType())
1889  return emitError("different element types specified for source type ")
1890  << srcType << " and result memref type " << resultType;
1891 
1892  // Match sizes in result memref type and in static_sizes attribute.
1893  for (auto [idx, resultSize, expectedSize] :
1894  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1895  if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
1896  return emitError("expected result type with size = ")
1897  << (ShapedType::isDynamic(expectedSize)
1898  ? std::string("dynamic")
1899  : std::to_string(expectedSize))
1900  << " instead of " << resultSize << " in dim = " << idx;
1901  }
1902 
1903  // Match offset and strides in static_offset and static_strides attributes. If
1904  // result memref type has no affine map specified, this will assume an
1905  // identity layout.
1906  int64_t resultOffset;
1907  SmallVector<int64_t, 4> resultStrides;
1908  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1909  return emitError("expected result type to have strided layout but found ")
1910  << resultType;
1911 
1912  // Match offset in result memref type and in static_offsets attribute.
1913  int64_t expectedOffset = getStaticOffsets().front();
1914  if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
1915  return emitError("expected result type with offset = ")
1916  << (ShapedType::isDynamic(expectedOffset)
1917  ? std::string("dynamic")
1918  : std::to_string(expectedOffset))
1919  << " instead of " << resultOffset;
1920 
1921  // Match strides in result memref type and in static_strides attribute.
1922  for (auto [idx, resultStride, expectedStride] :
1923  llvm::enumerate(resultStrides, getStaticStrides())) {
1924  if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
1925  return emitError("expected result type with stride = ")
1926  << (ShapedType::isDynamic(expectedStride)
1927  ? std::string("dynamic")
1928  : std::to_string(expectedStride))
1929  << " instead of " << resultStride << " in dim = " << idx;
1930  }
1931 
1932  return success();
1933 }
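// A sketch of an op that satisfies the checks above: the static entries in the
// offset/sizes/strides attributes agree with the result's shape and strided
// layout. Names and types are illustrative:
//
//   %dst = memref.reinterpret_cast %src to
//       offset: [0], sizes: [10, %d], strides: [%s, 1]
//       : memref<?x?xf32> to memref<10x?xf32, strided<[?, 1], offset: 0>>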
1934 
1935 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1936  Value src = getSource();
1937  auto getPrevSrc = [&]() -> Value {
1938  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1939  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1940  return prev.getSource();
1941 
1942  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1943  if (auto prev = src.getDefiningOp<CastOp>())
1944  return prev.getSource();
1945 
1946  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1947  // are 0.
1948  if (auto prev = src.getDefiningOp<SubViewOp>())
1949  if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1950  return isConstantIntValue(val, 0);
1951  }))
1952  return prev.getSource();
1953 
1954  return nullptr;
1955  };
1956 
1957  if (auto prevSrc = getPrevSrc()) {
1958  getSourceMutable().assign(prevSrc);
1959  return getResult();
1960  }
1961 
1962  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1963  if (!ShapedType::isDynamicShape(getType().getShape()) &&
1964  src.getType() == getType() && getStaticOffsets().front() == 0) {
1965  return src;
1966  }
1967 
1968  return nullptr;
1969 }
1970 
1971 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1972  SmallVector<OpFoldResult> values = getMixedSizes();
1973  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
1974  ShapedType::isDynamic);
1975  return values;
1976 }
1977 
1978 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1979  SmallVector<OpFoldResult> values = getMixedStrides();
1980  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
1981  ShapedType::isDynamic);
1982  return values;
1983 }
1984 
1985 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1986  SmallVector<OpFoldResult> values = getMixedOffsets();
1987  assert(values.size() == 1 &&
1988  "reinterpret_cast must have one and only one offset");
1989  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
1990  ShapedType::isDynamic);
1991  return values[0];
1992 }
1993 
1994 namespace {
1995 /// Replace the sequence:
1996 /// ```
1997 /// base, offset, sizes, strides = extract_strided_metadata src
1998 /// dst = reinterpret_cast base to offset, sizes, strides
1999 /// ```
2000 /// With
2001 ///
2002 /// ```
2003 /// dst = memref.cast src
2004 /// ```
2005 ///
2006 /// Note: The cast operation is only inserted when the type of dst and src
2007 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
2008 ///
2009 /// This pattern also matches when the offset, sizes, and strides don't come
2010 /// directly from the `extract_strided_metadata`'s results but it can be
2011 /// statically proven that they would hold the same values.
2012 ///
2013 /// For instance, the following sequence would be replaced:
2014 /// ```
2015 /// base, offset, sizes, strides =
2016 /// extract_strided_metadata memref : memref<3x4xty>
2017 /// dst = reinterpret_cast base to 0, [3, 4], strides
2018 /// ```
2019 /// Because we know (thanks to the type of the input memref) that variable
2020 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
2021 ///
2022 /// Similarly, the following sequence would be replaced:
2023 /// ```
2024 /// c0 = arith.constant 0
2025 /// c4 = arith.constant 4
2026 /// base, offset, sizes, strides =
2027 /// extract_strided_metadata memref : memref<3x4xty>
2028 /// dst = reinterpret_cast base to c0, [3, c4], strides
2029 /// ```
2030 /// Because we know that `offset` and `c0` will hold 0
2031 /// and `c4` will hold 4.
2032 struct ReinterpretCastOpExtractStridedMetadataFolder
2033  : public OpRewritePattern<ReinterpretCastOp> {
2034 public:
2035  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2036 
2037  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2038  PatternRewriter &rewriter) const override {
2039  auto extractStridedMetadata =
2040  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2041  if (!extractStridedMetadata)
2042  return failure();
2043  // Check if the reinterpret cast reconstructs a memref with the exact same
2044  // properties as the extract strided metadata.
2045 
2046  // First, check that the strides are the same.
2047  SmallVector<OpFoldResult> extractStridesOfr =
2048  extractStridedMetadata.getConstifiedMixedStrides();
2049  SmallVector<OpFoldResult> reinterpretStridesOfr =
2050  op.getConstifiedMixedStrides();
2051  if (extractStridesOfr.size() != reinterpretStridesOfr.size())
2052  return failure();
2053 
2054  unsigned rank = op.getType().getRank();
2055  for (unsigned i = 0; i < rank; ++i) {
2056  if (extractStridesOfr[i] != reinterpretStridesOfr[i])
2057  return failure();
2058  }
2059 
2060  // Second, check the sizes.
2061  assert(extractStridedMetadata.getSizes().size() ==
2062  op.getMixedSizes().size() &&
2063  "Strides and sizes rank must match");
2064  SmallVector<OpFoldResult> extractSizesOfr =
2065  extractStridedMetadata.getConstifiedMixedSizes();
2066  SmallVector<OpFoldResult> reinterpretSizesOfr =
2067  op.getConstifiedMixedSizes();
2068  for (unsigned i = 0; i < rank; ++i) {
2069  if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2070  return failure();
2071  }
2072  // Finally, check the offset.
2073  assert(op.getMixedOffsets().size() == 1 &&
2074  "reinterpret_cast with more than one offset should have been "
2075  "rejected by the verifier");
2076  OpFoldResult extractOffsetOfr =
2077  extractStridedMetadata.getConstifiedMixedOffset();
2078  OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2079  if (extractOffsetOfr != reinterpretOffsetOfr)
2080  return failure();
2081 
2082  // At this point, we know that the back and forth between extract strided
2083  // metadata and reinterpret cast is a noop. However, the final type of the
2084  // reinterpret cast may not be exactly the same as the original memref.
2085  // E.g., it could be changing a dimension from static to dynamic. Check that
2086  // here and add a cast if necessary.
2087  Type srcTy = extractStridedMetadata.getSource().getType();
2088  if (srcTy == op.getResult().getType())
2089  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2090  else
2091  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2092  extractStridedMetadata.getSource());
2093 
2094  return success();
2095  }
2096 };
2097 } // namespace
2098 
2099 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2100  MLIRContext *context) {
2101  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2102 }
2103 
2104 //===----------------------------------------------------------------------===//
2105 // Reassociative reshape ops
2106 //===----------------------------------------------------------------------===//
2107 
2108 void CollapseShapeOp::getAsmResultNames(
2109  function_ref<void(Value, StringRef)> setNameFn) {
2110  setNameFn(getResult(), "collapse_shape");
2111 }
2112 
2113 void ExpandShapeOp::getAsmResultNames(
2114  function_ref<void(Value, StringRef)> setNameFn) {
2115  setNameFn(getResult(), "expand_shape");
2116 }
2117 
2118 LogicalResult ExpandShapeOp::reifyResultShapes(
2119  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2120  reifiedResultShapes = {
2121  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2122  return success();
2123 }
2124 
2125 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2126 /// result and operand. Layout maps are verified separately.
2127 ///
2128 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2129 /// allowed in a reassociation group.
2130 static LogicalResult
2131 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2132  ArrayRef<int64_t> expandedShape,
2133  ArrayRef<ReassociationIndices> reassociation,
2134  bool allowMultipleDynamicDimsPerGroup) {
2135  // There must be one reassociation group per collapsed dimension.
2136  if (collapsedShape.size() != reassociation.size())
2137  return op->emitOpError("invalid number of reassociation groups: found ")
2138  << reassociation.size() << ", expected " << collapsedShape.size();
2139 
2140  // The next expected expanded dimension index (while iterating over
2141  // reassociation indices).
2142  int64_t nextDim = 0;
2143  for (const auto &it : llvm::enumerate(reassociation)) {
2144  ReassociationIndices group = it.value();
2145  int64_t collapsedDim = it.index();
2146 
2147  bool foundDynamic = false;
2148  for (int64_t expandedDim : group) {
2149  if (expandedDim != nextDim++)
2150  return op->emitOpError("reassociation indices must be contiguous");
2151 
2152  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2153  return op->emitOpError("reassociation index ")
2154  << expandedDim << " is out of bounds";
2155 
2156  // Check if there are multiple dynamic dims in a reassociation group.
2157  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2158  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2159  return op->emitOpError(
2160  "at most one dimension in a reassociation group may be dynamic");
2161  foundDynamic = true;
2162  }
2163  }
2164 
2165  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2166  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2167  return op->emitOpError("collapsed dim (")
2168  << collapsedDim
2169  << ") must be dynamic if and only if reassociation group is "
2170  "dynamic";
2171 
2172  // If all dims in the reassociation group are static, the size of the
2173  // collapsed dim can be verified.
2174  if (!foundDynamic) {
2175  int64_t groupSize = 1;
2176  for (int64_t expandedDim : group)
2177  groupSize *= expandedShape[expandedDim];
2178  if (groupSize != collapsedShape[collapsedDim])
2179  return op->emitOpError("collapsed dim size (")
2180  << collapsedShape[collapsedDim]
2181  << ") must equal reassociation group size (" << groupSize << ")";
2182  }
2183  }
2184 
2185  if (collapsedShape.empty()) {
2186  // Rank 0: All expanded dimensions must be 1.
2187  for (int64_t d : expandedShape)
2188  if (d != 1)
2189  return op->emitOpError(
2190  "rank 0 memrefs can only be extended/collapsed with/from ones");
2191  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2192  // Rank >= 1: Number of dimensions among all reassociation groups must match
2193  // the result memref rank.
2194  return op->emitOpError("expanded rank (")
2195  << expandedShape.size()
2196  << ") inconsistent with number of reassociation indices (" << nextDim
2197  << ")";
2198  }
2199 
2200  return success();
2201 }
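// A sketch of a reassociation that passes these checks; the shapes are
// illustrative:
//
//   %c = memref.collapse_shape %m [[0, 1], [2]]
//       : memref<2x5x4xf32> into memref<10x4xf32>
//
// Group {0, 1} is fully static, so the collapsed dim must equal 2 * 5 = 10;
// a dynamic collapsed dim is accepted only if its group has a dynamic dim.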
2202 
2203 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2204  return getSymbolLessAffineMaps(getReassociationExprs());
2205 }
2206 
2207 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2208  return convertReassociationIndicesToExprs(getContext(),
2209  getReassociationIndices());
2210 }
2211 
2212 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2213  return getSymbolLessAffineMaps(getReassociationExprs());
2214 }
2215 
2216 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2217  return convertReassociationIndicesToExprs(getContext(),
2218  getReassociationIndices());
2219 }
2220 
2221 /// Compute the layout map after expanding a given source MemRef type with the
2222 /// specified reassociation indices.
2223 static FailureOr<StridedLayoutAttr>
2224 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2225  ArrayRef<ReassociationIndices> reassociation) {
2226  int64_t srcOffset;
2227  SmallVector<int64_t> srcStrides;
2228  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2229  return failure();
2230  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2231 
2232  // 1-1 mapping between srcStrides and reassociation packs.
2233  // Each srcStride starts with the given value and gets expanded according to
2234  // the proper entries in resultShape.
2235  // Example:
2236  // srcStrides = [10000, 1 , 100 ],
2237  // reassociations = [ [0], [1], [2, 3, 4]],
2238  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2239  // -> For the purpose of stride calculation, the useful sizes are:
2240  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2241  // resultStrides = [10000, 1, 600, 200, 100]
2242  // Note that a stride does not get expanded along the first entry of each
2243  // shape pack.
2244  SmallVector<int64_t> reverseResultStrides;
2245  reverseResultStrides.reserve(resultShape.size());
2246  unsigned shapeIndex = resultShape.size() - 1;
2247  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2248  ReassociationIndices reassoc = std::get<0>(it);
2249  int64_t currentStrideToExpand = std::get<1>(it);
2250  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2251  reverseResultStrides.push_back(currentStrideToExpand);
2252  currentStrideToExpand =
2253  (SaturatedInteger::wrap(currentStrideToExpand) *
2254  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2255  .asInteger();
2256  }
2257  }
2258  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2259  resultStrides.resize(resultShape.size(), 1);
2260  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2261 }
2262 
2263 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2264  MemRefType srcType, ArrayRef<int64_t> resultShape,
2265  ArrayRef<ReassociationIndices> reassociation) {
2266  if (srcType.getLayout().isIdentity()) {
2267  // If the source is contiguous (i.e., no layout map specified), so is the
2268  // result.
2269  MemRefLayoutAttrInterface layout;
2270  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2271  srcType.getMemorySpace());
2272  }
2273 
2274  // Source may not be contiguous. Compute the layout map.
2275  FailureOr<StridedLayoutAttr> computedLayout =
2276  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2277  if (failed(computedLayout))
2278  return failure();
2279  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2280  srcType.getMemorySpace());
2281 }
2282 
2283 FailureOr<SmallVector<OpFoldResult>>
2284 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2285  MemRefType expandedType,
2286  ArrayRef<ReassociationIndices> reassociation,
2287  ArrayRef<OpFoldResult> inputShape) {
2288  std::optional<SmallVector<OpFoldResult>> outputShape =
2289  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2290  inputShape);
2291  if (!outputShape)
2292  return failure();
2293  return *outputShape;
2294 }
2295 
2296 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2297  Type resultType, Value src,
2298  ArrayRef<ReassociationIndices> reassociation,
2299  ArrayRef<OpFoldResult> outputShape) {
2300  auto [staticOutputShape, dynamicOutputShape] =
2301  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2302  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2303  getReassociationIndicesAttribute(builder, reassociation),
2304  dynamicOutputShape, staticOutputShape);
2305 }
2306 
2307 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2308  Type resultType, Value src,
2309  ArrayRef<ReassociationIndices> reassociation) {
2310  SmallVector<OpFoldResult> inputShape =
2311  getMixedSizes(builder, result.location, src);
2312  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2313  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2314  builder, result.location, memrefResultTy, reassociation, inputShape);
2315  // Failure of this assertion usually indicates presence of multiple
2316  // dynamic dimensions in the same reassociation group.
2317  assert(succeeded(outputShape) && "unable to infer output shape");
2318  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2319 }
2320 
2321 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2322  ArrayRef<int64_t> resultShape, Value src,
2323  ArrayRef<ReassociationIndices> reassociation) {
2324  // Only ranked memref source values are supported.
2325  auto srcType = llvm::cast<MemRefType>(src.getType());
2326  FailureOr<MemRefType> resultType =
2327  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2328  // Failure of this assertion usually indicates a problem with the source
2329  // type, e.g., could not get strides/offset.
2330  assert(succeeded(resultType) && "could not compute layout");
2331  build(builder, result, *resultType, src, reassociation);
2332 }
2333 
2334 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2335  ArrayRef<int64_t> resultShape, Value src,
2336  ArrayRef<ReassociationIndices> reassociation,
2337  ArrayRef<OpFoldResult> outputShape) {
2338  // Only ranked memref source values are supported.
2339  auto srcType = llvm::cast<MemRefType>(src.getType());
2340  FailureOr<MemRefType> resultType =
2341  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2342  // Failure of this assertion usually indicates a problem with the source
2343  // type, e.g., could not get strides/offset.
2344  assert(succeeded(resultType) && "could not compute layout");
2345  build(builder, result, *resultType, src, reassociation, outputShape);
2346 }
2347 
2348 LogicalResult ExpandShapeOp::verify() {
2349  MemRefType srcType = getSrcType();
2350  MemRefType resultType = getResultType();
2351 
2352  if (srcType.getRank() > resultType.getRank()) {
2353  auto r0 = srcType.getRank();
2354  auto r1 = resultType.getRank();
2355  return emitOpError("has source rank ")
2356  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2357  << r0 << " > " << r1 << ").";
2358  }
2359 
2360  // Verify result shape.
2361  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2362  resultType.getShape(),
2363  getReassociationIndices(),
2364  /*allowMultipleDynamicDimsPerGroup=*/true)))
2365  return failure();
2366 
2367  // Compute expected result type (including layout map).
2368  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2369  srcType, resultType.getShape(), getReassociationIndices());
2370  if (failed(expectedResultType))
2371  return emitOpError("invalid source layout map");
2372 
2373  // Check actual result type.
2374  if (*expectedResultType != resultType)
2375  return emitOpError("expected expanded type to be ")
2376  << *expectedResultType << " but found " << resultType;
2377 
2378  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2379  return emitOpError("expected number of static shape bounds to be equal to "
2380  "the output rank (")
2381  << resultType.getRank() << ") but found "
2382  << getStaticOutputShape().size() << " inputs instead";
2383 
2384  if ((int64_t)getOutputShape().size() !=
2385  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2386  return emitOpError("mismatch in dynamic dims in output_shape and "
2387  "static_output_shape: static_output_shape has ")
2388  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2389  << " dynamic dims while output_shape has " << getOutputShape().size()
2390  << " values";
2391 
2392  // Verify if provided output shapes are in agreement with output type.
2393  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2394  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2395  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2396  if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
2397  return emitOpError("invalid output shape provided at pos ") << pos;
2398  }
2399  }
2400 
2401  return success();
2402 }
2403 
2404 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2405  MLIRContext *context) {
2406  results.add<
2407  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2408  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2409 }
2410 
2411 /// Compute the layout map after collapsing a given source MemRef type with the
2412 /// specified reassociation indices.
2413 ///
2414 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2415 /// not possible to check this by inspecting a MemRefType in the general case.
2416 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2417 /// be valid (and thus accepted by this function) unless `strict = true`.
2418 static FailureOr<StridedLayoutAttr>
2419 computeCollapsedLayoutMap(MemRefType srcType,
2420  ArrayRef<ReassociationIndices> reassociation,
2421  bool strict = false) {
2422  int64_t srcOffset;
2423  SmallVector<int64_t> srcStrides;
2424  auto srcShape = srcType.getShape();
2425  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2426  return failure();
2427 
2428  // The result stride of a reassociation group is the stride of the last entry
2429  // of the reassociation. (TODO: Should be the minimum stride in the
2430  // reassociation because strides are not necessarily sorted. E.g., when using
2431  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2432  // strides are meaningless and could have any arbitrary value.
2433  SmallVector<int64_t> resultStrides;
2434  resultStrides.reserve(reassociation.size());
2435  for (const ReassociationIndices &reassoc : reassociation) {
2436  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2437  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2438  ref = ref.drop_back();
2439  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2440  resultStrides.push_back(srcStrides[ref.back()]);
2441  } else {
2442  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2443  // the corresponding stride may have to be skipped. (See above comment.)
2444  // Therefore, the result stride cannot be statically determined and must
2445  // be dynamic.
2446  resultStrides.push_back(ShapedType::kDynamic);
2447  }
2448  }
2449 
2450  // Validate that each reassociation group is contiguous.
2451  unsigned resultStrideIndex = resultStrides.size() - 1;
2452  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2453  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2454  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2455  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2456  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2457 
2458  // Both source and result stride must have the same static value. In that
2459  // case, we can be sure, that the dimensions are collapsible (because they
2460  // are contiguous).
2461  // If `strict = false` (default during op verification), we accept cases
2462  // where one or both strides are dynamic. This is best effort: We reject
2463  // ops where obviously non-contiguous dims are collapsed, but accept ops
2464  // where we cannot be sure statically. Such ops may fail at runtime. See
2465  // the op documentation for details.
2466  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2467  if (strict && (stride.saturated || srcStride.saturated))
2468  return failure();
2469 
2470  // Dimensions of size 1 should be skipped, because their strides are
2471  // meaningless and could have any arbitrary value.
2472  if (srcShape[idx - 1] == 1)
2473  continue;
2474 
2475  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2476  return failure();
2477  }
2478  }
2479  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2480 }
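// A worked example of the computation above, with an illustrative type:
// collapsing memref<2x4xf32, strided<[4, 1]>> with reassociation [[0, 1]]
// takes the stride of the last dim of the group (1) as the result stride; the
// contiguity check passes because 1 * 4 == 4, the stride of dim 0, so the
// collapsed type is memref<8xf32, strided<[1]>>.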
2481 
2482 bool CollapseShapeOp::isGuaranteedCollapsible(
2483  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2484  // MemRefs with identity layout are always collapsible.
2485  if (srcType.getLayout().isIdentity())
2486  return true;
2487 
2488  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2489  /*strict=*/true));
2490 }
2491 
2492 MemRefType CollapseShapeOp::computeCollapsedType(
2493  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2494  SmallVector<int64_t> resultShape;
2495  resultShape.reserve(reassociation.size());
2496  for (const ReassociationIndices &group : reassociation) {
2497  auto groupSize = SaturatedInteger::wrap(1);
2498  for (int64_t srcDim : group)
2499  groupSize =
2500  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2501  resultShape.push_back(groupSize.asInteger());
2502  }
2503 
2504  if (srcType.getLayout().isIdentity()) {
2505  // If the source is contiguous (i.e., no layout map specified), so is the
2506  // result.
2507  MemRefLayoutAttrInterface layout;
2508  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2509  srcType.getMemorySpace());
2510  }
2511 
2512  // Source may not be fully contiguous. Compute the layout map.
2513  // Note: Dimensions that are collapsed into a single dim are assumed to be
2514  // contiguous.
2515  FailureOr<StridedLayoutAttr> computedLayout =
2516  computeCollapsedLayoutMap(srcType, reassociation);
2517  assert(succeeded(computedLayout) &&
2518  "invalid source layout map or collapsing non-contiguous dims");
2519  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2520  srcType.getMemorySpace());
2521 }
2522 
2523 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2524  ArrayRef<ReassociationIndices> reassociation,
2525  ArrayRef<NamedAttribute> attrs) {
2526  auto srcType = llvm::cast<MemRefType>(src.getType());
2527  MemRefType resultType =
2528  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2529  result.addAttribute(getReassociationAttrStrName(),
2530  getReassociationIndicesAttribute(b, reassociation));
2531  build(b, result, resultType, src, attrs);
2532 }
2533 
2534 LogicalResult CollapseShapeOp::verify() {
2535  MemRefType srcType = getSrcType();
2536  MemRefType resultType = getResultType();
2537 
2538  if (srcType.getRank() < resultType.getRank()) {
2539  auto r0 = srcType.getRank();
2540  auto r1 = resultType.getRank();
2541  return emitOpError("has source rank ")
2542  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2543  << r0 << " < " << r1 << ").";
2544  }
2545 
2546  // Verify result shape.
2547  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2548  srcType.getShape(), getReassociationIndices(),
2549  /*allowMultipleDynamicDimsPerGroup=*/true)))
2550  return failure();
2551 
2552  // Compute expected result type (including layout map).
2553  MemRefType expectedResultType;
2554  if (srcType.getLayout().isIdentity()) {
2555  // If the source is contiguous (i.e., no layout map specified), so is the
2556  // result.
2557  MemRefLayoutAttrInterface layout;
2558  expectedResultType =
2559  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2560  srcType.getMemorySpace());
2561  } else {
2562  // Source may not be fully contiguous. Compute the layout map.
2563  // Note: Dimensions that are collapsed into a single dim are assumed to be
2564  // contiguous.
2565  FailureOr<StridedLayoutAttr> computedLayout =
2566  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2567  if (failed(computedLayout))
2568  return emitOpError(
2569  "invalid source layout map or collapsing non-contiguous dims");
2570  expectedResultType =
2571  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2572  *computedLayout, srcType.getMemorySpace());
2573  }
2574 
2575  if (expectedResultType != resultType)
2576  return emitOpError("expected collapsed type to be ")
2577  << expectedResultType << " but found " << resultType;
2578 
2579  return success();
2580 }
2581 
2582 struct CollapseShapeOpMemRefCastFolder
2583  : public OpRewritePattern<CollapseShapeOp> {
2584 public:
2585  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2586 
2587  LogicalResult matchAndRewrite(CollapseShapeOp op,
2588  PatternRewriter &rewriter) const override {
2589  auto cast = op.getOperand().getDefiningOp<CastOp>();
2590  if (!cast)
2591  return failure();
2592 
2593  if (!CastOp::canFoldIntoConsumerOp(cast))
2594  return failure();
2595 
2596  Type newResultType = CollapseShapeOp::computeCollapsedType(
2597  llvm::cast<MemRefType>(cast.getOperand().getType()),
2598  op.getReassociationIndices());
2599 
2600  if (newResultType == op.getResultType()) {
2601  rewriter.modifyOpInPlace(
2602  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2603  } else {
2604  Value newOp = rewriter.create<CollapseShapeOp>(
2605  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2606  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2607  }
2608  return success();
2609  }
2610 };
2611 
2612 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2613  MLIRContext *context) {
2614  results.add<
2615  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2616  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2617  memref::DimOp, MemRefType>,
2618  CollapseShapeOpMemRefCastFolder>(context);
2619 }
2620 
2621 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2622  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2623  adaptor.getOperands());
2624 }
2625 
2626 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2627  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2628  adaptor.getOperands());
2629 }
2630 
2631 //===----------------------------------------------------------------------===//
2632 // ReshapeOp
2633 //===----------------------------------------------------------------------===//
2634 
2635 void ReshapeOp::getAsmResultNames(
2636  function_ref<void(Value, StringRef)> setNameFn) {
2637  setNameFn(getResult(), "reshape");
2638 }
2639 
2640 LogicalResult ReshapeOp::verify() {
2641  Type operandType = getSource().getType();
2642  Type resultType = getResult().getType();
2643 
2644  Type operandElementType =
2645  llvm::cast<ShapedType>(operandType).getElementType();
2646  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2647  if (operandElementType != resultElementType)
2648  return emitOpError("element types of source and destination memref "
2649  "types should be the same");
2650 
2651  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2652  if (!operandMemRefType.getLayout().isIdentity())
2653  return emitOpError("source memref type should have identity affine map");
2654 
2655  int64_t shapeSize =
2656  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2657  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2658  if (resultMemRefType) {
2659  if (!resultMemRefType.getLayout().isIdentity())
2660  return emitOpError("result memref type should have identity affine map");
2661  if (shapeSize == ShapedType::kDynamic)
2662  return emitOpError("cannot use shape operand with dynamic length to "
2663  "reshape to statically-ranked memref type");
2664  if (shapeSize != resultMemRefType.getRank())
2665  return emitOpError(
2666  "length of shape operand differs from the result's memref rank");
2667  }
2668  return success();
2669 }
2670 
2671 //===----------------------------------------------------------------------===//
2672 // StoreOp
2673 //===----------------------------------------------------------------------===//
2674 
2675 LogicalResult StoreOp::verify() {
2676  if (getNumOperands() != 2 + getMemRefType().getRank())
2677  return emitOpError("store index operand count not equal to memref rank");
2678 
2679  return success();
2680 }
2681 
2682 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2683  SmallVectorImpl<OpFoldResult> &results) {
2684  /// store(memrefcast) -> store
2685  return foldMemRefCast(*this, getValueToStore());
2686 }
2687 
2688 //===----------------------------------------------------------------------===//
2689 // SubViewOp
2690 //===----------------------------------------------------------------------===//
2691 
2692 void SubViewOp::getAsmResultNames(
2693  function_ref<void(Value, StringRef)> setNameFn) {
2694  setNameFn(getResult(), "subview");
2695 }
2696 
2697 /// A subview result type can be fully inferred from the source type and the
2698 /// static representation of offsets, sizes and strides. Special sentinels
2699 /// encode the dynamic case.
2700 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2701  ArrayRef<int64_t> staticOffsets,
2702  ArrayRef<int64_t> staticSizes,
2703  ArrayRef<int64_t> staticStrides) {
2704  unsigned rank = sourceMemRefType.getRank();
2705  (void)rank;
2706  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2707  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2708  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2709 
2710  // Extract source offset and strides.
2711  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);
2712 
2713  // Compute target offset whose value is:
2714  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2715  int64_t targetOffset = sourceOffset;
2716  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2717  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2718  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2719  SaturatedInteger::wrap(staticOffset) *
2720  SaturatedInteger::wrap(sourceStride))
2721  .asInteger();
2722  }
2723 
2724  // Compute target stride whose value is:
2725  // `sourceStrides_i * staticStrides_i`.
2726  SmallVector<int64_t, 4> targetStrides;
2727  targetStrides.reserve(staticOffsets.size());
2728  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2729  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2730  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2731  SaturatedInteger::wrap(staticStride))
2732  .asInteger());
2733  }
2734 
2735  // The type is now known.
2736  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2737  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2738  targetOffset, targetStrides),
2739  sourceMemRefType.getMemorySpace());
2740 }
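// A worked example of the arithmetic above, with an illustrative source type:
// for memref<8x16xf32> (strides [16, 1], offset 0), offsets [2, 4],
// sizes [4, 8], and strides [1, 2] give
//   targetOffset  = 0 + 2 * 16 + 4 * 1 = 36
//   targetStrides = [16 * 1, 1 * 2]    = [16, 2]
// so the inferred type is memref<4x8xf32, strided<[16, 2], offset: 36>>.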
2741 
2742 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2743  ArrayRef<OpFoldResult> offsets,
2744  ArrayRef<OpFoldResult> sizes,
2745  ArrayRef<OpFoldResult> strides) {
2746  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2747  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2748  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2749  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2750  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2751  if (!hasValidSizesOffsets(staticOffsets))
2752  return {};
2753  if (!hasValidSizesOffsets(staticSizes))
2754  return {};
2755  if (!hasValidStrides(staticStrides))
2756  return {};
2757  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2758  staticSizes, staticStrides);
2759 }
2760 
2761 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2762  MemRefType sourceRankedTensorType,
2763  ArrayRef<int64_t> offsets,
2764  ArrayRef<int64_t> sizes,
2765  ArrayRef<int64_t> strides) {
2766  auto inferredType = llvm::cast<MemRefType>(
2767  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2768  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2769  "expected ");
2770  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2771  return inferredType;
2772 
2773  // Compute which dimensions are dropped.
2774  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2775  computeRankReductionMask(inferredType.getShape(), resultShape);
2776  assert(dimsToProject.has_value() && "invalid rank reduction");
2777 
2778  // Compute the layout and result type.
2779  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2780  SmallVector<int64_t> rankReducedStrides;
2781  rankReducedStrides.reserve(resultShape.size());
2782  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2783  if (!dimsToProject->contains(idx))
2784  rankReducedStrides.push_back(value);
2785  }
2786  return MemRefType::get(resultShape, inferredType.getElementType(),
2787  StridedLayoutAttr::get(inferredLayout.getContext(),
2788  inferredLayout.getOffset(),
2789  rankReducedStrides),
2790  inferredType.getMemorySpace());
2791 }
2792 
2793 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2794  MemRefType sourceRankedTensorType,
2795  ArrayRef<OpFoldResult> offsets,
2796  ArrayRef<OpFoldResult> sizes,
2797  ArrayRef<OpFoldResult> strides) {
2798  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2799  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2800  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2801  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2802  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2803  return SubViewOp::inferRankReducedResultType(
2804  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2805  staticStrides);
2806 }
2807 
2808 // Build a SubViewOp with mixed static and dynamic entries and custom result
2809 // type. If the type passed is nullptr, it is inferred.
2810 void SubViewOp::build(OpBuilder &b, OperationState &result,
2811  MemRefType resultType, Value source,
2812  ArrayRef<OpFoldResult> offsets,
2813  ArrayRef<OpFoldResult> sizes,
2814  ArrayRef<OpFoldResult> strides,
2815  ArrayRef<NamedAttribute> attrs) {
2816  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2817  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2818  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2819  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2820  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2821  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2822  // Structuring implementation this way avoids duplication between builders.
2823  if (!resultType) {
2824  resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2825  sourceMemRefType, staticOffsets, staticSizes, staticStrides));
2826  }
2827  result.addAttributes(attrs);
2828  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2829  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2830  b.getDenseI64ArrayAttr(staticSizes),
2831  b.getDenseI64ArrayAttr(staticStrides));
2832 }
2833 
2834 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2835 // type.
2836 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2837  ArrayRef<OpFoldResult> offsets,
2838  ArrayRef<OpFoldResult> sizes,
2839  ArrayRef<OpFoldResult> strides,
2840  ArrayRef<NamedAttribute> attrs) {
2841  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2842 }
2843 
2844 // Build a SubViewOp with static entries and inferred result type.
2845 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2846  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2847  ArrayRef<int64_t> strides,
2848  ArrayRef<NamedAttribute> attrs) {
2849  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2850  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2851  return b.getI64IntegerAttr(v);
2852  }));
2853  SmallVector<OpFoldResult> sizeValues =
2854  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2855  return b.getI64IntegerAttr(v);
2856  }));
2857  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2858  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2859  return b.getI64IntegerAttr(v);
2860  }));
2861  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2862 }
2863 
2864 // Build a SubViewOp with static entries and custom result
2865 // type. If the type passed is nullptr, it is inferred.
2866 void SubViewOp::build(OpBuilder &b, OperationState &result,
2867  MemRefType resultType, Value source,
2868  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2869  ArrayRef<int64_t> strides,
2870  ArrayRef<NamedAttribute> attrs) {
2871  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2872  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2873  return b.getI64IntegerAttr(v);
2874  }));
2875  SmallVector<OpFoldResult> sizeValues =
2876  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2877  return b.getI64IntegerAttr(v);
2878  }));
2879  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2880  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2881  return b.getI64IntegerAttr(v);
2882  }));
2883  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2884  attrs);
2885 }
2886 
2887 // Build a SubViewOp with dynamic entries and custom result type. If the type
2888 // passed is nullptr, it is inferred.
2889 void SubViewOp::build(OpBuilder &b, OperationState &result,
2890  MemRefType resultType, Value source, ValueRange offsets,
2891  ValueRange sizes, ValueRange strides,
2892  ArrayRef<NamedAttribute> attrs) {
2893  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2894  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2895  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2896  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2897  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2898  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2899  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2900 }
2901 
2902 // Build a SubViewOp with dynamic entries and inferred result type.
2903 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2904  ValueRange offsets, ValueRange sizes, ValueRange strides,
2905  ArrayRef<NamedAttribute> attrs) {
2906  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2907 }
2908 
2909 /// For ViewLikeOpInterface.
2910 Value SubViewOp::getViewSource() { return getSource(); }
2911 
2912 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2913 /// static value).
2914 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2915  int64_t t1Offset, t2Offset;
2916  SmallVector<int64_t> t1Strides, t2Strides;
2917  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2918  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2919  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2920 }
2921 
2922 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2923 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2924 /// marked as dropped in `droppedDims`.
2925 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2926  const llvm::SmallBitVector &droppedDims) {
2927  assert(size_t(t1.getRank()) == droppedDims.size() &&
2928  "incorrect number of bits");
2929  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2930  "incorrect number of dropped dims");
2931  int64_t t1Offset, t2Offset;
2932  SmallVector<int64_t> t1Strides, t2Strides;
2933  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2934  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2935  if (failed(res1) || failed(res2))
2936  return false;
2937  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2938  if (droppedDims[i])
2939  continue;
2940  if (t1Strides[i] != t2Strides[j])
2941  return false;
2942  ++j;
2943  }
2944  return true;
2945 }
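// Editorial note (illustrative sketch, not from the original source): for a
// hypothetical expected type t1 = memref<4x1x8xf32, strided<[32, 8, 1]>> and a
// rank-reduced result t2 = memref<4x8xf32, strided<[32, 1]>>, droppedDims is
// {1}; the loop above skips the stride at the dropped position and compares
// only [32, 1] against [32, 1]. Such a pair arises from a subview like:
// ```
// %sv = memref.subview %src[0, 0, 0] [4, 1, 8] [1, 1, 1]
//     : memref<4x4x8xf32> to memref<4x8xf32, strided<[32, 1]>>
// ```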
2946 
2947 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2948  Operation *op, Type expectedType) {
2949  auto memrefType = llvm::cast<ShapedType>(expectedType);
2950  switch (result) {
2951  case SliceVerificationResult::Success:
2952  return success();
2953  case SliceVerificationResult::RankTooLarge:
2954  return op->emitError("expected result rank to be smaller or equal to ")
2955  << "the source rank. ";
2956  case SliceVerificationResult::SizeMismatch:
2957  return op->emitError("expected result type to be ")
2958  << expectedType
2959  << " or a rank-reduced version. (mismatch of result sizes) ";
2960  case SliceVerificationResult::ElemTypeMismatch:
2961  return op->emitError("expected result element type to be ")
2962  << memrefType.getElementType();
2963  case SliceVerificationResult::MemSpaceMismatch:
2964  return op->emitError("expected result and source memory spaces to match.");
2965  case SliceVerificationResult::LayoutMismatch:
2966  return op->emitError("expected result type to be ")
2967  << expectedType
2968  << " or a rank-reduced version. (mismatch of result layout) ";
2969  }
2970  llvm_unreachable("unexpected subview verification result");
2971 }
2972 
2973 /// Verifier for SubViewOp.
2974 LogicalResult SubViewOp::verify() {
2975  MemRefType baseType = getSourceType();
2976  MemRefType subViewType = getType();
2977 
2978  // The base memref and the view memref should be in the same memory space.
2979  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2980  return emitError("different memory spaces specified for base memref "
2981  "type ")
2982  << baseType << " and subview memref type " << subViewType;
2983 
2984  // Verify that the base memref type has a strided layout map.
2985  if (!isStrided(baseType))
2986  return emitError("base type ") << baseType << " is not strided";
2987 
2988  // Compute the expected result type, assuming that there are no rank
2989  // reductions.
2990  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
2991  baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
2992 
2993  // Verify all properties of a shaped type: rank, element type and dimension
2994  // sizes. This takes into account potential rank reductions.
2995  auto shapedTypeVerification = isRankReducedType(
2996  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2997  if (shapedTypeVerification != SliceVerificationResult::Success)
2998  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2999 
3000  // Make sure that the memory space did not change.
3001  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3002  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
3003  *this, expectedType);
3004 
3005  // Verify the offset of the layout map.
3006  if (!haveCompatibleOffsets(expectedType, subViewType))
3007  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3008  *this, expectedType);
3009 
3010  // The only thing that's left to verify now are the strides. First, compute
3011  // the unused dimensions due to rank reductions. We have to look at sizes and
3012  // strides to decide which dimensions were dropped. This function also
3013  // partially verifies strides in case of rank reductions.
3014  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3015  getMixedSizes());
3016  if (failed(unusedDims))
3017  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3018  *this, expectedType);
3019 
3020  // Strides must match.
3021  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3022  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3023  *this, expectedType);
3024 
3025  return success();
3026 }
3027 
3028 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3029  return os << "range " << range.offset << ":" << range.size << ":"
3030  << range.stride;
3031 }
3032 
3033 /// Return the list of Range (i.e. offset, size, stride). Each Range
3034 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3035 /// with `b` at location `loc`.
3036 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3037  OpBuilder &b, Location loc) {
3038  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3039  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3040  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3041  SmallVector<Range, 8> res;
3042  unsigned rank = ranks[0];
3043  res.reserve(rank);
3044  for (unsigned idx = 0; idx < rank; ++idx) {
3045  Value offset =
3046  op.isDynamicOffset(idx)
3047  ? op.getDynamicOffset(idx)
3048  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
3049  Value size =
3050  op.isDynamicSize(idx)
3051  ? op.getDynamicSize(idx)
3052  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
3053  Value stride =
3054  op.isDynamicStride(idx)
3055  ? op.getDynamicStride(idx)
3056  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
3057  res.emplace_back(Range{offset, size, stride});
3058  }
3059  return res;
3060 }
3061 
3062 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3063 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3064 /// the rank of the inferred result type if `currentResultType` is lower rank
3065 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3066 /// together with the result type. In this case, it is important to compute
3067 /// the dropped dimensions using `currentSourceType` whose strides align with
3068 /// `currentResultType`.
3069 static MemRefType getCanonicalSubViewResultType(
3070  MemRefType currentResultType, MemRefType currentSourceType,
3071  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3072  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3073  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
3074  sourceType, mixedOffsets, mixedSizes, mixedStrides));
3075  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3076  currentSourceType, currentResultType, mixedSizes);
3077  if (failed(unusedDims))
3078  return nullptr;
3079 
3080  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3081  SmallVector<int64_t> shape, strides;
3082  unsigned numDimsAfterReduction =
3083  nonRankReducedType.getRank() - unusedDims->count();
3084  shape.reserve(numDimsAfterReduction);
3085  strides.reserve(numDimsAfterReduction);
3086  for (const auto &[idx, size, stride] :
3087  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3088  nonRankReducedType.getShape(), layout.getStrides())) {
3089  if (unusedDims->test(idx))
3090  continue;
3091  shape.push_back(size);
3092  strides.push_back(stride);
3093  }
3094 
3095  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3096  StridedLayoutAttr::get(sourceType.getContext(),
3097  layout.getOffset(), strides),
3098  nonRankReducedType.getMemorySpace());
3099 }
3100 
3101 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3102  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3103  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3104  unsigned rank = memrefType.getRank();
3105  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3106  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3107  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3108  auto targetType =
3109  llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
3110  targetShape, memrefType, offsets, sizes, strides));
3111  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3112  sizes, strides);
3113 }
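// Editorial note (illustrative sketch, not from the original source): for a
// hypothetical %m : memref<1x16xf32>, a call such as
// createCanonicalRankReducingSubViewOp(b, loc, %m, {16}) builds a zero-offset,
// unit-stride, full-size subview with a rank-reduced result type, roughly:
// ```
// %sv = memref.subview %m[0, 0] [1, 16] [1, 1]
//     : memref<1x16xf32> to memref<16xf32, strided<[1]>>
// ```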
3114 
3115 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3116  Value value,
3117  ArrayRef<int64_t> desiredShape) {
3118  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3119  assert(sourceMemrefType && "not a ranked memref type");
3120  auto sourceShape = sourceMemrefType.getShape();
3121  if (sourceShape.equals(desiredShape))
3122  return value;
3123  auto maybeRankReductionMask =
3124  mlir::computeRankReductionMask(sourceShape, desiredShape);
3125  if (!maybeRankReductionMask)
3126  return failure();
3127  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3128 }
3129 
3130 /// Helper method to check if a `subview` operation is trivially a no-op. This
3131 /// is the case if all offsets are zero, all strides are 1, and the source
3132 /// shape is the same as the size of the subview. In such cases, the subview can
3133 /// be folded into its source.
3134 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3135  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3136  return false;
3137 
3138  auto mixedOffsets = subViewOp.getMixedOffsets();
3139  auto mixedSizes = subViewOp.getMixedSizes();
3140  auto mixedStrides = subViewOp.getMixedStrides();
3141 
3142  // Check offsets are zero.
3143  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3144  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3145  return !intValue || intValue.value() != 0;
3146  }))
3147  return false;
3148 
3149  // Check strides are one.
3150  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3151  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3152  return !intValue || intValue.value() != 1;
3153  }))
3154  return false;
3155 
3156  // Check that all size values are static and match the (static) source shape.
3157  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3158  for (const auto &size : llvm::enumerate(mixedSizes)) {
3159  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3160  if (!intValue || *intValue != sourceShape[size.index()])
3161  return false;
3162  }
3163  // All conditions met. The `SubViewOp` is foldable as a no-op.
3164  return true;
3165 }
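// Editorial note (illustrative sketch, not from the original source): a
// trivially no-op subview in the sense above covers the whole source with
// zero offsets and unit strides, e.g.
// ```
// %sv = memref.subview %arg[0, 0] [8, 16] [1, 1]
//     : memref<8x16xf32> to memref<8x16xf32, strided<[16, 1]>>
// ```
// TrivialSubViewOpFolder below then replaces %sv with %arg, or with a
// memref.cast when only the layout of the types differs.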
3166 
3167 namespace {
3168 /// Pattern to rewrite a subview op with MemRefCast arguments.
3169 /// This essentially pushes memref.cast past its consuming subview when
3170 /// `canFoldIntoConsumerOp` is true.
3171 ///
3172 /// Example:
3173 /// ```
3174 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3175 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3176 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3177 /// ```
3178 /// is rewritten into:
3179 /// ```
3180 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
3181 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3182 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3183 /// ```
3184 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3185 public:
3186  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3187 
3188  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3189  PatternRewriter &rewriter) const override {
3190  // If any operand is constant, just return and let the constant-argument
3191  // folding pattern kick in.
3192  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3193  return matchPattern(operand, matchConstantIndex());
3194  }))
3195  return failure();
3196 
3197  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3198  if (!castOp)
3199  return failure();
3200 
3201  if (!CastOp::canFoldIntoConsumerOp(castOp))
3202  return failure();
3203 
3204  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3205  // the MemRefCastOp source operand type to infer the result type and the
3206  // current SubViewOp source operand type to compute the dropped dimensions
3207  // if the operation is rank-reducing.
3208  auto resultType = getCanonicalSubViewResultType(
3209  subViewOp.getType(), subViewOp.getSourceType(),
3210  llvm::cast<MemRefType>(castOp.getSource().getType()),
3211  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3212  subViewOp.getMixedStrides());
3213  if (!resultType)
3214  return failure();
3215 
3216  Value newSubView = rewriter.create<SubViewOp>(
3217  subViewOp.getLoc(), resultType, castOp.getSource(),
3218  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3219  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3220  subViewOp.getStaticStrides());
3221  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3222  newSubView);
3223  return success();
3224  }
3225 };
3226 
3227 /// Canonicalize subview ops that are no-ops. When the source and result types
3228 /// differ only in their layout (e.g. due to use of `affine_map`), a cast is emitted instead.
3229 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3230 public:
3231  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3232 
3233  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3234  PatternRewriter &rewriter) const override {
3235  if (!isTrivialSubViewOp(subViewOp))
3236  return failure();
3237  if (subViewOp.getSourceType() == subViewOp.getType()) {
3238  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3239  return success();
3240  }
3241  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3242  subViewOp.getSource());
3243  return success();
3244  }
3245 };
3246 } // namespace
3247 
3248 /// Return the canonical type of the result of a subview.
3249 struct SubViewReturnTypeCanonicalizer {
3250  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3251  ArrayRef<OpFoldResult> mixedSizes,
3252  ArrayRef<OpFoldResult> mixedStrides) {
3253  // Infer a memref type without taking into account any rank reductions.
3254  auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
3255  mixedSizes, mixedStrides);
3256  if (!resTy)
3257  return {};
3258  MemRefType nonReducedType = cast<MemRefType>(resTy);
3259 
3260  // Directly return the non-rank reduced type if there are no dropped dims.
3261  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3262  if (droppedDims.none())
3263  return nonReducedType;
3264 
3265  // Take the strides and offset from the non-rank reduced type.
3266  auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
3267 
3268  // Drop dims from shape and strides.
3269  SmallVector<int64_t> targetShape;
3270  SmallVector<int64_t> targetStrides;
3271  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3272  if (droppedDims.test(i))
3273  continue;
3274  targetStrides.push_back(nonReducedStrides[i]);
3275  targetShape.push_back(nonReducedType.getDimSize(i));
3276  }
3277 
3278  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3279  StridedLayoutAttr::get(nonReducedType.getContext(),
3280  offset, targetStrides),
3281  nonReducedType.getMemorySpace());
3282  }
3283 };
3284 
3285 /// A canonicalizer wrapper to replace SubViewOps.
3286 struct SubViewCanonicalizer {
3287  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3288  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3289  }
3290 };
3291 
3292 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3293  MLIRContext *context) {
3294  results
3295  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3296  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3297  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3298 }
3299 
3300 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3301  MemRefType sourceMemrefType = getSource().getType();
3302  MemRefType resultMemrefType = getResult().getType();
3303  auto resultLayout =
3304  dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3305 
3306  if (resultMemrefType == sourceMemrefType &&
3307  resultMemrefType.hasStaticShape() &&
3308  (!resultLayout || resultLayout.hasStaticLayout())) {
3309  return getViewSource();
3310  }
3311 
3312  // Fold subview(subview(x)), where both subviews have the same sizes, and the
3313  // second subview's offsets are all zero and its strides are all one (i.e.,
3314  // the second subview is a no-op).
3315  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3316  auto srcSizes = srcSubview.getMixedSizes();
3317  auto sizes = getMixedSizes();
3318  auto offsets = getMixedOffsets();
3319  bool allOffsetsZero = llvm::all_of(
3320  offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3321  auto strides = getMixedStrides();
3322  bool allStridesOne = llvm::all_of(
3323  strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3324  bool allSizesSame = llvm::equal(sizes, srcSizes);
3325  if (allOffsetsZero && allStridesOne && allSizesSame &&
3326  resultMemrefType == sourceMemrefType)
3327  return getViewSource();
3328  }
3329 
3330  return {};
3331 }
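// Editorial note (illustrative sketch, not from the original source): the
// subview-of-subview case above folds IR such as
// ```
// %a = memref.subview %src[%i, 0] [4, 8] [1, 1]
//     : memref<16x8xf32> to memref<4x8xf32, strided<[8, 1], offset: ?>>
// %b = memref.subview %a[0, 0] [4, 8] [1, 1]
//     : memref<4x8xf32, strided<[8, 1], offset: ?>>
//       to memref<4x8xf32, strided<[8, 1], offset: ?>>
// ```
// where %b has all-zero offsets, unit strides, the same sizes as %a, and a
// result type equal to its source type, so %b folds to %a.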
3332 
3333 //===----------------------------------------------------------------------===//
3334 // TransposeOp
3335 //===----------------------------------------------------------------------===//
3336 
3337 void TransposeOp::getAsmResultNames(
3338  function_ref<void(Value, StringRef)> setNameFn) {
3339  setNameFn(getResult(), "transpose");
3340 }
3341 
3342 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3343 static MemRefType inferTransposeResultType(MemRefType memRefType,
3344  AffineMap permutationMap) {
3345  auto originalSizes = memRefType.getShape();
3346  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
3347  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3348 
3349  // Compute permuted sizes and strides.
3350  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3351  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3352 
3353  return MemRefType::Builder(memRefType)
3354  .setShape(sizes)
3355  .setLayout(
3356  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3357 }
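// Editorial note (illustrative sketch, not from the original source): for a
// hypothetical %m : memref<4x8xf32> (strides [8, 1]), the permutation
// (d0, d1) -> (d1, d0) permutes both the sizes and the strides:
// ```
// %t = memref.transpose %m (d0, d1) -> (d1, d0)
//     : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
// ```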
3358 
3359 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3360  AffineMapAttr permutation,
3361  ArrayRef<NamedAttribute> attrs) {
3362  auto permutationMap = permutation.getValue();
3363  assert(permutationMap);
3364 
3365  auto memRefType = llvm::cast<MemRefType>(in.getType());
3366  // Compute result type.
3367  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3368 
3369  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3370  build(b, result, resultType, in, attrs);
3371 }
3372 
3373 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3374 void TransposeOp::print(OpAsmPrinter &p) {
3375  p << " " << getIn() << " " << getPermutation();
3376  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3377  p << " : " << getIn().getType() << " to " << getType();
3378 }
3379 
3380 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3381  OpAsmParser::UnresolvedOperand in;
3382  AffineMap permutation;
3383  MemRefType srcType, dstType;
3384  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3385  parser.parseOptionalAttrDict(result.attributes) ||
3386  parser.parseColonType(srcType) ||
3387  parser.resolveOperand(in, srcType, result.operands) ||
3388  parser.parseKeywordType("to", dstType) ||
3389  parser.addTypeToList(dstType, result.types))
3390  return failure();
3391 
3392  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3393  AffineMapAttr::get(permutation));
3394  return success();
3395 }
3396 
3397 LogicalResult TransposeOp::verify() {
3398  if (!getPermutation().isPermutation())
3399  return emitOpError("expected a permutation map");
3400  if (getPermutation().getNumDims() != getIn().getType().getRank())
3401  return emitOpError("expected a permutation map of same rank as the input");
3402 
3403  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3404  auto resultType = llvm::cast<MemRefType>(getType());
3405  auto canonicalResultType = canonicalizeStridedLayout(
3406  inferTransposeResultType(srcType, getPermutation()));
3407 
3408  if (canonicalizeStridedLayout(resultType) != canonicalResultType)
3409  return emitOpError("result type ")
3410  << resultType
3411  << " is not equivalent to the canonical transposed input type "
3412  << canonicalResultType;
3413  return success();
3414 }
3415 
3416 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3417  // First check for an identity permutation; we can fold it away if the input
3418  // and result types are already identical.
3419  if (getPermutation().isIdentity() && getType() == getIn().getType())
3420  return getIn();
3421  // Fold two consecutive memref.transpose Ops into one by composing their
3422  // permutation maps.
3423  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3424  AffineMap composedPermutation =
3425  getPermutation().compose(otherTransposeOp.getPermutation());
3426  getInMutable().assign(otherTransposeOp.getIn());
3427  setPermutation(composedPermutation);
3428  return getResult();
3429  }
3430  return {};
3431 }
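// Editorial note (illustrative sketch, not from the original source): two
// back-to-back transposes compose in the fold above. For a hypothetical
// %m : memref<4x8xf32>:
// ```
// %t0 = memref.transpose %m (d0, d1) -> (d1, d0)
//     : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
// %t1 = memref.transpose %t0 (d0, d1) -> (d1, d0)
//     : memref<8x4xf32, strided<[1, 8]>> to memref<4x8xf32, strided<[8, 1]>>
// ```
// After folding, %t1 reads %m directly through the composed (identity)
// permutation; its result type is unchanged.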
3432 
3433 //===----------------------------------------------------------------------===//
3434 // ViewOp
3435 //===----------------------------------------------------------------------===//
3436 
3437 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3438  setNameFn(getResult(), "view");
3439 }
3440 
3441 LogicalResult ViewOp::verify() {
3442  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3443  auto viewType = getType();
3444 
3445  // The base memref should have identity layout map (or none).
3446  if (!baseType.getLayout().isIdentity())
3447  return emitError("unsupported map for base memref type ") << baseType;
3448 
3449  // The result memref should have identity layout map (or none).
3450  if (!viewType.getLayout().isIdentity())
3451  return emitError("unsupported map for result memref type ") << viewType;
3452 
3453  // The base memref and the view memref should be in the same memory space.
3454  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3455  return emitError("different memory spaces specified for base memref "
3456  "type ")
3457  << baseType << " and view memref type " << viewType;
3458 
3459  // Verify that we have the correct number of sizes for the result type.
3460  unsigned numDynamicDims = viewType.getNumDynamicDims();
3461  if (getSizes().size() != numDynamicDims)
3462  return emitError("incorrect number of size operands for type ") << viewType;
3463 
3464  return success();
3465 }
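// Editorial note (illustrative sketch, not from the original source): a view
// takes a byte shift plus one size operand per dynamic result dimension, and
// both the base and the result use the identity layout. Assuming hypothetical
// %buf, %offset and %n values:
// ```
// %v = memref.view %buf[%offset][%n]
//     : memref<2048xi8> to memref<?x4xf32>
// ```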
3466 
3467 Value ViewOp::getViewSource() { return getSource(); }
3468 
3469 namespace {
3470 
3471 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3472  using OpRewritePattern<ViewOp>::OpRewritePattern;
3473 
3474  LogicalResult matchAndRewrite(ViewOp viewOp,
3475  PatternRewriter &rewriter) const override {
3476  // Return if none of the operands are constants.
3477  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3478  return matchPattern(operand, matchConstantIndex());
3479  }))
3480  return failure();
3481 
3482  // Get result memref type.
3483  auto memrefType = viewOp.getType();
3484 
3485  // Get offset from old memref view type 'memRefType'.
3486  int64_t oldOffset;
3487  SmallVector<int64_t, 4> oldStrides;
3488  if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
3489  return failure();
3490  assert(oldOffset == 0 && "Expected 0 offset");
3491 
3492  SmallVector<Value, 4> newOperands;
3493 
3494  // Offset cannot be folded into result type.
3495 
3496  // Fold any dynamic dim operands which are produced by a constant.
3497  SmallVector<int64_t, 4> newShapeConstants;
3498  newShapeConstants.reserve(memrefType.getRank());
3499 
3500  unsigned dynamicDimPos = 0;
3501  unsigned rank = memrefType.getRank();
3502  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3503  int64_t dimSize = memrefType.getDimSize(dim);
3504  // If this is already a static dimension, keep it.
3505  if (!ShapedType::isDynamic(dimSize)) {
3506  newShapeConstants.push_back(dimSize);
3507  continue;
3508  }
3509  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3510  if (auto constantIndexOp =
3511  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3512  // Dynamic shape dimension will be folded.
3513  newShapeConstants.push_back(constantIndexOp.value());
3514  } else {
3515  // Dynamic shape dimension not folded; copy operand from old memref.
3516  newShapeConstants.push_back(dimSize);
3517  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3518  }
3519  dynamicDimPos++;
3520  }
3521 
3522  // Create new memref type with constant folded dims.
3523  MemRefType newMemRefType =
3524  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3525  // Nothing new, don't fold.
3526  if (newMemRefType == memrefType)
3527  return failure();
3528 
3529  // Create new ViewOp.
3530  auto newViewOp = rewriter.create<ViewOp>(
3531  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3532  viewOp.getByteShift(), newOperands);
3533  // Insert a cast so we have the same type as the old memref type.
3534  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3535  return success();
3536  }
3537 };
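// Editorial note (illustrative sketch, not from the original source): the
// pattern above folds constant size operands into the result type (the byte
// shift itself is never folded). For a hypothetical input
// ```
// %c16 = arith.constant 16 : index
// %v = memref.view %buf[%offset][%c16] : memref<2048xi8> to memref<?x4xf32>
// ```
// it produces a static view plus a cast back to the original type:
// ```
// %v0 = memref.view %buf[%offset][] : memref<2048xi8> to memref<16x4xf32>
// %v1 = memref.cast %v0 : memref<16x4xf32> to memref<?x4xf32>
// ```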
3538 
3539 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3540  using OpRewritePattern<ViewOp>::OpRewritePattern;
3541 
3542  LogicalResult matchAndRewrite(ViewOp viewOp,
3543  PatternRewriter &rewriter) const override {
3544  Value memrefOperand = viewOp.getOperand(0);
3545  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3546  if (!memrefCastOp)
3547  return failure();
3548  Value allocOperand = memrefCastOp.getOperand();
3549  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3550  if (!allocOp)
3551  return failure();
3552  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3553  viewOp.getByteShift(),
3554  viewOp.getSizes());
3555  return success();
3556  }
3557 };
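// Editorial note (illustrative sketch, not from the original source): the
// pattern above looks through a cast of an allocation, e.g.
// ```
// %alloc = memref.alloc() : memref<2048xi8>
// %cast  = memref.cast %alloc : memref<2048xi8> to memref<?xi8>
// %v     = memref.view %cast[%offset][%n] : memref<?xi8> to memref<?x4xf32>
// ```
// and rebuilds the view directly on %alloc:
// ```
// %v = memref.view %alloc[%offset][%n] : memref<2048xi8> to memref<?x4xf32>
// ```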
3558 
3559 } // namespace
3560 
3561 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3562  MLIRContext *context) {
3563  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3564 }
3565 
3566 //===----------------------------------------------------------------------===//
3567 // AtomicRMWOp
3568 //===----------------------------------------------------------------------===//
3569 
3570 LogicalResult AtomicRMWOp::verify() {
3571  if (getMemRefType().getRank() != getNumOperands() - 2)
3572  return emitOpError(
3573  "expects the number of subscripts to be equal to memref rank");
3574  switch (getKind()) {
3575  case arith::AtomicRMWKind::addf:
3576  case arith::AtomicRMWKind::maximumf:
3577  case arith::AtomicRMWKind::minimumf:
3578  case arith::AtomicRMWKind::mulf:
3579  if (!llvm::isa<FloatType>(getValue().getType()))
3580  return emitOpError() << "with kind '"
3581  << arith::stringifyAtomicRMWKind(getKind())
3582  << "' expects a floating-point type";
3583  break;
3584  case arith::AtomicRMWKind::addi:
3585  case arith::AtomicRMWKind::maxs:
3586  case arith::AtomicRMWKind::maxu:
3587  case arith::AtomicRMWKind::mins:
3588  case arith::AtomicRMWKind::minu:
3589  case arith::AtomicRMWKind::muli:
3590  case arith::AtomicRMWKind::ori:
3591  case arith::AtomicRMWKind::andi:
3592  if (!llvm::isa<IntegerType>(getValue().getType()))
3593  return emitOpError() << "with kind '"
3594  << arith::stringifyAtomicRMWKind(getKind())
3595  << "' expects an integer type";
3596  break;
3597  default:
3598  break;
3599  }
3600  return success();
3601 }
3602 
3603 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3604  /// atomicrmw(memrefcast) -> atomicrmw
3605  if (succeeded(foldMemRefCast(*this, getValue())))
3606  return getResult();
3607  return OpFoldResult();
3608 }
3609 
3610 //===----------------------------------------------------------------------===//
3611 // TableGen'd op method definitions
3612 //===----------------------------------------------------------------------===//
3613 
3614 #define GET_OP_CLASSES
3615 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
static bool hasSideEffects(Operation *op)
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isPermutation(std::vector< PermutationTy > permutation)
Definition: IRAffine.cpp:71
static MLIRContext * getContext(OpFoldResult val)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static SmallVector< int64_t > getConstantOffset(MemRefType memrefType)
Wrapper around getStridesAndOffset that returns only the offset and conforms to the function signatur...
Definition: MemRefOps.cpp:163
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, MemRefType memRefTy, MLIRContext *ctxt, llvm::function_ref< SmallVector< int64_t >(MemRefType)> getAttributes, llvm::function_ref< bool(int64_t)> isDynamic)
Helper function that infers the constant values from a list of values, a memRefTy,...
Definition: MemRefOps.cpp:115
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
Definition: MemRefOps.cpp:1551
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
Definition: MemRefOps.cpp:2131
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
Definition: MemRefOps.cpp:449
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
Definition: MemRefOps.cpp:3343
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
Definition: MemRefOps.cpp:430
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
Definition: MemRefOps.cpp:2914
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
Definition: MemRefOps.cpp:1398
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
Definition: MemRefOps.cpp:3069
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, Operation *op, Type expectedType)
Definition: MemRefOps.cpp:2947
static SmallVector< int64_t > getConstantStrides(MemRefType memrefType)
Wrapper around getStridesAndOffset that returns only the strides and conforms to the function signatu...
Definition: MemRefOps.cpp:176
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition: MemRefOps.cpp:1565
static SmallVector< int64_t > getConstantSizes(MemRefType memRefTy)
Wrapper around getShape that conforms to the function signature expected for getAttributes in constif...
Definition: MemRefOps.cpp:155
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
Definition: MemRefOps.cpp:938
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
Definition: MemRefOps.cpp:3134
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
Definition: MemRefOps.cpp:472
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
Definition: MemRefOps.cpp:2925
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurrences of it.
Definition: MemRefOps.cpp:923
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
Definition: MemRefOps.cpp:2224
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
Definition: MemRefOps.cpp:2419
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
Definition: MemRefOps.cpp:201
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition: TensorOps.cpp:136
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:149
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:95
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:213
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:234
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:224
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:420
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Definition: Operation.cpp:227
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Region.h:98
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIndex() const
Definition: Types.cpp:64
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition: MemRefOps.cpp:59
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition: MemRefOps.cpp:67
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Definition: MemRefOps.cpp:3101
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:318
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:522
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:387
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
Definition: MemRefOps.cpp:3036
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:24
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
Move allocations into an allocation scope, if it is legal to move them (e.g.
Definition: MemRefOps.cpp:519
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:522
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
Definition: MemRefOps.cpp:479
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:482
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:2587
A canonicalizer wrapper to replace SubViewOps.
Definition: MemRefOps.cpp:3286
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Definition: MemRefOps.cpp:3287
Return the canonical type of the result of a subview.
Definition: MemRefOps.cpp:3249
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: MemRefOps.cpp:3250
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.