MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 /// Materialize a single constant operation from a given attribute value with
30 /// the desired resultant type.
31 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
32  Attribute value, Type type,
33  Location loc) {
34  return arith::ConstantOp::materialize(builder, value, type, loc);
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // Common canonicalization pattern support logic
39 //===----------------------------------------------------------------------===//
40 
41 /// This is a common class used for patterns of the form
42 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
43 /// into the root operation directly.
44 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
45  bool folded = false;
46  for (OpOperand &operand : op->getOpOperands()) {
47  auto cast = operand.get().getDefiningOp<CastOp>();
48  if (cast && operand.get() != inner &&
49  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
50  operand.set(cast.getOperand());
51  folded = true;
52  }
53  }
54  return success(folded);
55 }
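 
// Example (illustrative, with hypothetical SSA names): for a cast-consuming op
// such as memref.dealloc,
//
//   %0 = memref.cast %arg : memref<4xf32> to memref<?xf32>
//   memref.dealloc %0 : memref<?xf32>
//
// folds to
//
//   memref.dealloc %arg : memref<4xf32>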
56 
57 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
58 /// type.
59 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
60  if (auto memref = llvm::dyn_cast<MemRefType>(type))
61  return RankedTensorType::get(memref.getShape(), memref.getElementType());
62  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
63  return UnrankedTensorType::get(memref.getElementType());
64  return NoneType::get(type.getContext());
65 }
66 
67 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
68  int64_t dim) {
69  auto memrefType = llvm::cast<MemRefType>(value.getType());
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78  Location loc, Value value) {
79  auto memrefType = llvm::cast<MemRefType>(value.getType());
80  SmallVector<OpFoldResult> result;
81  for (int64_t i = 0; i < memrefType.getRank(); ++i)
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result;
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that infers the constant values from a list of \p values,
91 /// a \p memRefTy, and another helper function \p getAttributes.
92 /// The inferred constant values replace the related `OpFoldResult` in
93 /// \p values.
94 ///
95 /// \note This function shouldn't be used directly, instead, use the
96 /// `getConstifiedMixedXXX` methods from the related operations.
97 ///
98 /// \p getAttributes returns a list of potentially constant values, as determined
99 /// by \p isDynamic, from the given \p memRefTy. The returned list must have as
100 /// many elements as \p values or be empty.
101 ///
102 /// E.g., consider the following example:
103 /// ```
104 /// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
105 /// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
106 /// ```
107 /// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
108 /// Now using this helper function with:
109 /// - `values == [2, %dyn_stride]`,
110 /// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
111 /// - `getAttributes == getConstantStrides` (i.e., a wrapper around
112 /// `getStridesAndOffset`), and
113 /// - `isDynamic == ShapedType::isDynamic`
114 /// Will yield: `values == [2, 1]`
115 static void constifyIndexValues(
116  SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
117  MLIRContext *ctxt,
118  llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
119  llvm::function_ref<bool(int64_t)> isDynamic) {
120  SmallVector<int64_t> constValues = getAttributes(memRefTy);
121  Builder builder(ctxt);
122  for (const auto &it : llvm::enumerate(constValues)) {
123  int64_t constValue = it.value();
124  if (!isDynamic(constValue))
125  values[it.index()] = builder.getIndexAttr(constValue);
126  }
127  for (OpFoldResult &ofr : values) {
128  if (ofr.is<Attribute>()) {
129  // FIXME: We shouldn't need to do that, but right now, the static indices
130  // are created with the wrong type: `i64` instead of `index`.
131  // As a result, if we were to keep the attribute as is, we may fail to see
132  // that two attributes are equal because one would have the i64 type and
133  // the other the index type.
134  // The alternative would be to create constant indices with getI64Attr in
135  // this and the previous loop, but it doesn't logically make sense (we are
136 // dealing with indices here) and would only strengthen the inconsistency
137  // around how static indices are created (some places use getI64Attr,
138  // others use getIndexAttr).
139  // The workaround here is to stick to the IndexAttr type for all the
140  // values, hence we recreate the attribute even when it is already static
141  // to make sure the type is consistent.
142  ofr = builder.getIndexAttr(
143  llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
144  continue;
145  }
146  std::optional<int64_t> maybeConstant =
147  getConstantIntValue(ofr.get<Value>());
148  if (maybeConstant)
149  ofr = builder.getIndexAttr(*maybeConstant);
150  }
151 }
152 
153 /// Wrapper around `getShape` that conforms to the function signature
154 /// expected for `getAttributes` in `constifyIndexValues`.
155 static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
156  ArrayRef<int64_t> sizes = memRefTy.getShape();
157  return SmallVector<int64_t>(sizes);
158 }
159 
160 /// Wrapper around `getStridesAndOffset` that returns only the offset and
161 /// conforms to the function signature expected for `getAttributes` in
162 /// `constifyIndexValues`.
163 static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
164  SmallVector<int64_t> strides;
165  int64_t offset;
166  LogicalResult hasStaticInformation =
167  getStridesAndOffset(memrefType, strides, offset);
168  if (failed(hasStaticInformation))
169  return SmallVector<int64_t>();
170  return SmallVector<int64_t>(1, offset);
171 }
172 
173 /// Wrapper around `getStridesAndOffset` that returns only the strides and
174 /// conforms to the function signature expected for `getAttributes` in
175 /// `constifyIndexValues`.
176 static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
177  SmallVector<int64_t> strides;
178  int64_t offset;
179  LogicalResult hasStaticInformation =
180  getStridesAndOffset(memrefType, strides, offset);
181  if (failed(hasStaticInformation))
182  return SmallVector<int64_t>();
183  return strides;
184 }
185 
186 //===----------------------------------------------------------------------===//
187 // AllocOp / AllocaOp
188 //===----------------------------------------------------------------------===//
189 
190 void AllocOp::getAsmResultNames(
191  function_ref<void(Value, StringRef)> setNameFn) {
192  setNameFn(getResult(), "alloc");
193 }
194 
195 void AllocaOp::getAsmResultNames(
196  function_ref<void(Value, StringRef)> setNameFn) {
197  setNameFn(getResult(), "alloca");
198 }
199 
200 template <typename AllocLikeOp>
201 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
202  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
203  "applies to only alloc or alloca");
204  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
205  if (!memRefType)
206  return op.emitOpError("result must be a memref");
207 
208  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
209  return op.emitOpError("dimension operand count does not equal memref "
210  "dynamic dimension count");
211 
212  unsigned numSymbols = 0;
213  if (!memRefType.getLayout().isIdentity())
214  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
215  if (op.getSymbolOperands().size() != numSymbols)
216  return op.emitOpError("symbol operand count does not equal memref symbol "
217  "count: expected ")
218  << numSymbols << ", got " << op.getSymbolOperands().size();
219 
220  return success();
221 }
222 
223 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
224 
225 LogicalResult AllocaOp::verify() {
226  // An alloca op needs to have an ancestor with an allocation scope trait.
227  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
228  return emitOpError(
229  "requires an ancestor op with AutomaticAllocationScope trait");
230 
231  return verifyAllocLikeOp(*this);
232 }
233 
234 namespace {
235 /// Fold constant dimensions into an alloc like operation.
236 template <typename AllocLikeOp>
237 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
238  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
239 
240  LogicalResult matchAndRewrite(AllocLikeOp alloc,
241  PatternRewriter &rewriter) const override {
242  // Check to see if any dimension operands are constants. If so, we can
243  // substitute and drop them.
244  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
245  APInt constSizeArg;
246  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
247  return false;
248  return constSizeArg.isNonNegative();
249  }))
250  return failure();
251 
252  auto memrefType = alloc.getType();
253 
254  // Ok, we have one or more constant operands. Collect the non-constant ones
255  // and keep track of the resultant memref type to build.
256  SmallVector<int64_t, 4> newShapeConstants;
257  newShapeConstants.reserve(memrefType.getRank());
258  SmallVector<Value, 4> dynamicSizes;
259 
260  unsigned dynamicDimPos = 0;
261  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
262  int64_t dimSize = memrefType.getDimSize(dim);
263  // If this is already a static dimension, keep it.
264  if (!ShapedType::isDynamic(dimSize)) {
265  newShapeConstants.push_back(dimSize);
266  continue;
267  }
268  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
269  APInt constSizeArg;
270  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
271  constSizeArg.isNonNegative()) {
272  // Dynamic shape dimension will be folded.
273  newShapeConstants.push_back(constSizeArg.getZExtValue());
274  } else {
275  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
276  newShapeConstants.push_back(ShapedType::kDynamic);
277  dynamicSizes.push_back(dynamicSize);
278  }
279  dynamicDimPos++;
280  }
281 
282  // Create new memref type (which will have fewer dynamic dimensions).
283  MemRefType newMemRefType =
284  MemRefType::Builder(memrefType).setShape(newShapeConstants);
285  assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
286 
287  // Create and insert the alloc op for the new memref.
288  auto newAlloc = rewriter.create<AllocLikeOp>(
289  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
290  alloc.getAlignmentAttr());
291  // Insert a cast so we have the same type as the old alloc.
292  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
293  return success();
294  }
295 };
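 
// Example (illustrative, with hypothetical SSA names) of the rewrite performed
// by SimplifyAllocConst:
//
//   %c8 = arith.constant 8 : index
//   %0 = memref.alloc(%c8, %n) : memref<?x?xf32>
//
// becomes
//
//   %1 = memref.alloc(%n) : memref<8x?xf32>
//   %0 = memref.cast %1 : memref<8x?xf32> to memref<?x?xf32>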
296 
297 /// Fold alloc operations with no users or only store and dealloc uses.
298 template <typename T>
299 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
300  using OpRewritePattern<T>::OpRewritePattern;
301 
302  LogicalResult matchAndRewrite(T alloc,
303  PatternRewriter &rewriter) const override {
304  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
305  if (auto storeOp = dyn_cast<StoreOp>(op))
306  return storeOp.getValue() == alloc;
307  return !isa<DeallocOp>(op);
308  }))
309  return failure();
310 
311  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
312  rewriter.eraseOp(user);
313 
314  rewriter.eraseOp(alloc);
315  return success();
316  }
317 };
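 
// Example (illustrative): an allocation whose only uses are stores into it and
// deallocations is erased together with those uses by SimplifyDeadAlloc, e.g.
//
//   %0 = memref.alloc() : memref<4xf32>
//   memref.store %v, %0[%i] : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>
//
// is removed entirely.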
318 } // namespace
319 
320 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
321  MLIRContext *context) {
322  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
323 }
324 
325 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
326  MLIRContext *context) {
327  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
328  context);
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // ReallocOp
333 //===----------------------------------------------------------------------===//
334 
335 LogicalResult ReallocOp::verify() {
336  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
337  MemRefType resultType = getType();
338 
339  // The source memref should have identity layout (or none).
340  if (!sourceType.getLayout().isIdentity())
341  return emitError("unsupported layout for source memref type ")
342  << sourceType;
343 
344  // The result memref should have identity layout (or none).
345  if (!resultType.getLayout().isIdentity())
346  return emitError("unsupported layout for result memref type ")
347  << resultType;
348 
349  // The source memref and the result memref should be in the same memory space.
350  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
351  return emitError("different memory spaces specified for source memref "
352  "type ")
353  << sourceType << " and result memref type " << resultType;
354 
355  // The source memref and the result memref should have the same element type.
356  if (sourceType.getElementType() != resultType.getElementType())
357  return emitError("different element types specified for source memref "
358  "type ")
359  << sourceType << " and result memref type " << resultType;
360 
361  // Verify that we have the dynamic dimension operand when it is needed.
362  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
363  return emitError("missing dimension operand for result type ")
364  << resultType;
365  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
366  return emitError("unnecessary dimension operand for result type ")
367  << resultType;
368 
369  return success();
370 }
371 
372 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
373  MLIRContext *context) {
374  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // AllocaScopeOp
379 //===----------------------------------------------------------------------===//
380 
381 void AllocaScopeOp::print(OpAsmPrinter &p) {
382  bool printBlockTerminators = false;
383 
384  p << ' ';
385  if (!getResults().empty()) {
386  p << " -> (" << getResultTypes() << ")";
387  printBlockTerminators = true;
388  }
389  p << ' ';
390  p.printRegion(getBodyRegion(),
391  /*printEntryBlockArgs=*/false,
392  /*printBlockTerminators=*/printBlockTerminators);
393  p.printOptionalAttrDict((*this)->getAttrs());
394 }
395 
396 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
397  // Create a region for the body.
398  result.regions.reserve(1);
399  Region *bodyRegion = result.addRegion();
400 
401  // Parse optional results type list.
402  if (parser.parseOptionalArrowTypeList(result.types))
403  return failure();
404 
405  // Parse the body region.
406  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
407  return failure();
408  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
409  result.location);
410 
411  // Parse the optional attribute list.
412  if (parser.parseOptionalAttrDict(result.attributes))
413  return failure();
414 
415  return success();
416 }
417 
418 void AllocaScopeOp::getSuccessorRegions(
419  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
420  if (!point.isParent()) {
421  regions.push_back(RegionSuccessor(getResults()));
422  return;
423  }
424 
425  regions.push_back(RegionSuccessor(&getBodyRegion()));
426 }
427 
428 /// Given an operation, return whether this op is guaranteed to
429 /// allocate an AutomaticAllocationScopeResource
430 static bool isGuaranteedAutomaticAllocation(Operation *op) {
431  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
432  if (!interface)
433  return false;
434  for (auto res : op->getResults()) {
435  if (auto effect =
436  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
437  if (isa<SideEffects::AutomaticAllocationScopeResource>(
438  effect->getResource()))
439  return true;
440  }
441  }
442  return false;
443 }
444 
445 /// Given an operation, return whether this op itself could
446 /// allocate an AutomaticAllocationScopeResource. Note that
447 /// this will not check whether an operation contained within
448 /// the op can allocate.
449 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
450  // This op itself doesn't create a stack allocation,
451  // the inner allocation should be handled separately.
452  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
453  return false;
454  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
455  if (!interface)
456  return true;
457  for (auto res : op->getResults()) {
458  if (auto effect =
459  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
460  if (isa<SideEffects::AutomaticAllocationScopeResource>(
461  effect->getResource()))
462  return true;
463  }
464  }
465  return false;
466 }
467 
468 /// Return whether this op is the last non terminating op
469 /// in a region. That is to say, it is in a one-block region
470 /// and is only followed by a terminator. This prevents
471 /// extending the lifetime of allocations.
472 static bool lastNonTerminatorInRegion(Operation *op) {
473  return op->getNextNode() == op->getBlock()->getTerminator() &&
474  op->getParentRegion()->getBlocks().size() == 1;
475 }
476 
477 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
478 /// or it contains no allocation.
479 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
480  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
481 
482  LogicalResult matchAndRewrite(AllocaScopeOp op,
483  PatternRewriter &rewriter) const override {
484  bool hasPotentialAlloca =
485  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
486  if (alloc == op)
487  return WalkResult::advance();
488  if (isOpItselfPotentialAutomaticAllocation(alloc))
489  return WalkResult::interrupt();
490  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
491  return WalkResult::skip();
492  return WalkResult::advance();
493  }).wasInterrupted();
494 
495  // If this contains no potential allocation, it is always legal to
496  // inline. Otherwise, consider two conditions:
497  if (hasPotentialAlloca) {
498  // If the parent isn't an allocation scope, or we are not the last
499  // non-terminator op in the parent, we will extend the lifetime.
500  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
501  return failure();
502  if (!lastNonTerminatorInRegion(op))
503  return failure();
504  }
505 
506  Block *block = &op.getRegion().front();
507  Operation *terminator = block->getTerminator();
508  ValueRange results = terminator->getOperands();
509  rewriter.inlineBlockBefore(block, op);
510  rewriter.replaceOp(op, results);
511  rewriter.eraseOp(terminator);
512  return success();
513  }
514 };
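 
// Example (illustrative): a scope that performs no stack allocation, such as
//
//   memref.alloca_scope {
//     ... ops that do not allocate ...
//   }
//
// is inlined into its parent block by AllocaScopeInliner.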
515 
516 /// Move allocations into an allocation scope, if it is legal to
517 /// move them (e.g. their operands are available at the location
518 /// the op would be moved to).
519 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
520  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
521 
522  LogicalResult matchAndRewrite(AllocaScopeOp op,
523  PatternRewriter &rewriter) const override {
524 
526  return failure();
527 
528  Operation *lastParentWithoutScope = op->getParentOp();
529 
530  if (!lastParentWithoutScope ||
531  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
532  return failure();
533 
534  // Only apply if this is the last non-terminator
535  // op in the block (lest the lifetime be extended) of a
536  // one-block region.
537  if (!lastNonTerminatorInRegion(op) ||
538  !lastNonTerminatorInRegion(lastParentWithoutScope))
539  return failure();
540 
541  while (!lastParentWithoutScope->getParentOp()
542  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
543  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
544  if (!lastParentWithoutScope ||
545  !lastNonTerminatorInRegion(lastParentWithoutScope))
546  return failure();
547  }
548  assert(lastParentWithoutScope->getParentOp()
549  ->hasTrait<OpTrait::AutomaticAllocationScope>());
550 
551  Region *containingRegion = nullptr;
552  for (auto &r : lastParentWithoutScope->getRegions()) {
553  if (r.isAncestor(op->getParentRegion())) {
554  assert(containingRegion == nullptr &&
555  "only one region can contain the op");
556  containingRegion = &r;
557  }
558  }
559  assert(containingRegion && "op must be contained in a region");
560 
561  SmallVector<Operation *> toHoist;
562  op->walk([&](Operation *alloc) {
563  if (!isGuaranteedAutomaticAllocation(alloc))
564  return WalkResult::skip();
565 
566  // If any operand is not defined before the location of
567  // lastParentWithoutScope (i.e. where we would hoist to), skip.
568  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
569  return containingRegion->isAncestor(v.getParentRegion());
570  }))
571  return WalkResult::skip();
572  toHoist.push_back(alloc);
573  return WalkResult::advance();
574  });
575 
576  if (toHoist.empty())
577  return failure();
578  rewriter.setInsertionPoint(lastParentWithoutScope);
579  for (auto *op : toHoist) {
580  auto *cloned = rewriter.clone(*op);
581  rewriter.replaceOp(op, cloned->getResults());
582  }
583  return success();
584  }
585 };
586 
587 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
588  MLIRContext *context) {
589  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // AssumeAlignmentOp
594 //===----------------------------------------------------------------------===//
595 
596 LogicalResult AssumeAlignmentOp::verify() {
597  if (!llvm::isPowerOf2_32(getAlignment()))
598  return emitOpError("alignment must be power of 2");
599  return success();
600 }
601 
602 //===----------------------------------------------------------------------===//
603 // CastOp
604 //===----------------------------------------------------------------------===//
605 
606 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
607  setNameFn(getResult(), "cast");
608 }
609 
610 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
611 /// source memref. This is useful to fold a memref.cast into a consuming op
612 /// and implement canonicalization patterns for ops in different dialects that
613 /// may consume the results of memref.cast operations. Such foldable memref.cast
614 /// operations are typically inserted as `view` and `subview` ops are
615 /// canonicalized, to preserve the type compatibility of their uses.
616 ///
617 /// Returns true when all conditions are met:
618 /// 1. source and result are ranked memrefs with strided semantics and same
619 /// element type and rank.
620 /// 2. each of the source's size, offset or stride has more static information
621 /// than the corresponding result's size, offset or stride.
622 ///
623 /// Example 1:
624 /// ```mlir
625 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
626 /// %2 = consumer %1 ... : memref<?x?xf32> ...
627 /// ```
628 ///
629 /// may fold into:
630 ///
631 /// ```mlir
632 /// %2 = consumer %0 ... : memref<8x16xf32> ...
633 /// ```
634 ///
635 /// Example 2:
636 /// ```
637 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
638 /// to memref<?x?xf32>
639 /// consumer %1 : memref<?x?xf32> ...
640 /// ```
641 ///
642 /// may fold into:
643 ///
644 /// ```
645 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
646 /// ```
647 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
648  MemRefType sourceType =
649  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
650  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
651 
652  // Requires ranked MemRefType.
653  if (!sourceType || !resultType)
654  return false;
655 
656  // Requires same elemental type.
657  if (sourceType.getElementType() != resultType.getElementType())
658  return false;
659 
660  // Requires same rank.
661  if (sourceType.getRank() != resultType.getRank())
662  return false;
663 
664  // Only fold casts between strided memref forms.
665  int64_t sourceOffset, resultOffset;
666  SmallVector<int64_t, 4> sourceStrides, resultStrides;
667  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
668  failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
669  return false;
670 
671  // If cast is towards more static sizes along any dimension, don't fold.
672  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
673  auto ss = std::get<0>(it), st = std::get<1>(it);
674  if (ss != st)
675  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
676  return false;
677  }
678 
679  // If cast is towards more static offset along any dimension, don't fold.
680  if (sourceOffset != resultOffset)
681  if (ShapedType::isDynamic(sourceOffset) &&
682  !ShapedType::isDynamic(resultOffset))
683  return false;
684 
685  // If cast is towards more static strides along any dimension, don't fold.
686  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
687  auto ss = std::get<0>(it), st = std::get<1>(it);
688  if (ss != st)
689  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
690  return false;
691  }
692 
693  return true;
694 }
695 
696 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
697  if (inputs.size() != 1 || outputs.size() != 1)
698  return false;
699  Type a = inputs.front(), b = outputs.front();
700  auto aT = llvm::dyn_cast<MemRefType>(a);
701  auto bT = llvm::dyn_cast<MemRefType>(b);
702 
703  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
704  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
705 
706  if (aT && bT) {
707  if (aT.getElementType() != bT.getElementType())
708  return false;
709  if (aT.getLayout() != bT.getLayout()) {
710  int64_t aOffset, bOffset;
711  SmallVector<int64_t, 4> aStrides, bStrides;
712  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
713  failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
714  aStrides.size() != bStrides.size())
715  return false;
716 
717  // Strides along a dimension/offset are compatible if the value in the
718  // source memref is static and the value in the target memref is the
719  // same. They are also compatible if either one is dynamic (see
720  // description of MemRefCastOp for details).
721  auto checkCompatible = [](int64_t a, int64_t b) {
722  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
723  };
724  if (!checkCompatible(aOffset, bOffset))
725  return false;
726  for (const auto &aStride : enumerate(aStrides))
727  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
728  return false;
729  }
730  if (aT.getMemorySpace() != bT.getMemorySpace())
731  return false;
732 
733  // They must have the same rank, and any specified dimensions must match.
734  if (aT.getRank() != bT.getRank())
735  return false;
736 
737  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
738  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
739  if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
740  aDim != bDim)
741  return false;
742  }
743  return true;
744  } else {
745  if (!aT && !uaT)
746  return false;
747  if (!bT && !ubT)
748  return false;
749  // Unranked to unranked casting is unsupported
750  if (uaT && ubT)
751  return false;
752 
753  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
754  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
755  if (aEltType != bEltType)
756  return false;
757 
758  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
759  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
760  return aMemSpace == bMemSpace;
761  }
762 
763  return false;
764 }
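 
// Example (illustrative): the following cast is compatible because each static
// dimension either matches or is paired with a dynamic one:
//
//   %1 = memref.cast %0 : memref<?x16xf32> to memref<4x?xf32>
//
// whereas casts between different element types, ranks, or memory spaces are
// rejected.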
765 
766 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
767  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // CopyOp
772 //===----------------------------------------------------------------------===//
773 
774 namespace {
775 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
776 /// and element type, the cast can be skipped. Such CastOps only cast the layout
777 /// of the type.
778 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
779  using OpRewritePattern<CopyOp>::OpRewritePattern;
780 
781  LogicalResult matchAndRewrite(CopyOp copyOp,
782  PatternRewriter &rewriter) const override {
783  bool modified = false;
784 
785  // Check source.
786  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
787  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
788  auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
789 
790  if (fromType && toType) {
791  if (fromType.getShape() == toType.getShape() &&
792  fromType.getElementType() == toType.getElementType()) {
793  rewriter.modifyOpInPlace(copyOp, [&] {
794  copyOp.getSourceMutable().assign(castOp.getSource());
795  });
796  modified = true;
797  }
798  }
799  }
800 
801  // Check target.
802  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
803  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
804  auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
805 
806  if (fromType && toType) {
807  if (fromType.getShape() == toType.getShape() &&
808  fromType.getElementType() == toType.getElementType()) {
809  rewriter.modifyOpInPlace(copyOp, [&] {
810  copyOp.getTargetMutable().assign(castOp.getSource());
811  });
812  modified = true;
813  }
814  }
815  }
816 
817  return success(modified);
818  }
819 };
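 
// Example (illustrative) of the rewrite performed by FoldCopyOfCast: a cast
// that only changes the layout of the type is skipped, e.g.
//
//   %0 = memref.cast %src : memref<4xf32> to memref<4xf32, strided<[1]>>
//   memref.copy %0, %dst : memref<4xf32, strided<[1]>> to memref<4xf32>
//
// becomes
//
//   memref.copy %src, %dst : memref<4xf32> to memref<4xf32>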
820 
821 /// Fold memref.copy(%x, %x).
822 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
823  using OpRewritePattern<CopyOp>::OpRewritePattern;
824 
825  LogicalResult matchAndRewrite(CopyOp copyOp,
826  PatternRewriter &rewriter) const override {
827  if (copyOp.getSource() != copyOp.getTarget())
828  return failure();
829 
830  rewriter.eraseOp(copyOp);
831  return success();
832  }
833 };
834 
835 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
836  using OpRewritePattern<CopyOp>::OpRewritePattern;
837 
838  static bool isEmptyMemRef(BaseMemRefType type) {
839  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
840  }
841 
842  LogicalResult matchAndRewrite(CopyOp copyOp,
843  PatternRewriter &rewriter) const override {
844  if (isEmptyMemRef(copyOp.getSource().getType()) ||
845  isEmptyMemRef(copyOp.getTarget().getType())) {
846  rewriter.eraseOp(copyOp);
847  return success();
848  }
849 
850  return failure();
851  }
852 };
853 } // namespace
854 
855 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
856  MLIRContext *context) {
857  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
858 }
859 
860 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
861  SmallVectorImpl<OpFoldResult> &results) {
862  /// copy(memrefcast) -> copy
863  bool folded = false;
864  Operation *op = *this;
865  for (OpOperand &operand : op->getOpOperands()) {
866  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
867  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
868  operand.set(castOp.getOperand());
869  folded = true;
870  }
871  }
872  return success(folded);
873 }
874 
875 //===----------------------------------------------------------------------===//
876 // DeallocOp
877 //===----------------------------------------------------------------------===//
878 
879 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
880  SmallVectorImpl<OpFoldResult> &results) {
881  /// dealloc(memrefcast) -> dealloc
882  return foldMemRefCast(*this);
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // DimOp
887 //===----------------------------------------------------------------------===//
888 
889 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
890  setNameFn(getResult(), "dim");
891 }
892 
893 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
894  int64_t index) {
895  auto loc = result.location;
896  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
897  build(builder, result, source, indexValue);
898 }
899 
900 std::optional<int64_t> DimOp::getConstantIndex() {
901  return getConstantIntValue(getIndex());
902 }
903 
904 Speculation::Speculatability DimOp::getSpeculatability() {
905  auto constantIndex = getConstantIndex();
906  if (!constantIndex)
907  return Speculation::NotSpeculatable;
908 
909  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
910  if (!rankedSourceType)
911  return Speculation::NotSpeculatable;
912 
913  if (rankedSourceType.getRank() <= constantIndex)
914  return Speculation::NotSpeculatable;
915 
916  return Speculation::Speculatable;
917 }
918 
919 /// Return a map from each element in `vals` to its number of
920 /// occurrences. Use std::map, since the `vals` here are strides and the
921 /// dynamic stride value is the same as the tombstone value for
922 /// `DenseMap<int64_t>`.
923 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
924  std::map<int64_t, unsigned> numOccurences;
925  for (auto val : vals)
926  numOccurences[val]++;
927  return numOccurences;
928 }
929 
930 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
931 /// to be a subset of `originalType` with some `1` entries erased, return the
932 /// set of indices that specifies which of the entries of `originalShape` are
933 /// dropped to obtain `reducedShape`.
934 /// This accounts for cases where there are multiple unit-dims, but only a
935 /// subset of those are dropped. For MemRefTypes these can be disambiguated
936 /// using the strides. If a dimension is dropped the stride must be dropped too.
937 static FailureOr<llvm::SmallBitVector>
938 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
939  ArrayRef<OpFoldResult> sizes) {
940  llvm::SmallBitVector unusedDims(originalType.getRank());
941  if (originalType.getRank() == reducedType.getRank())
942  return unusedDims;
943 
944  for (const auto &dim : llvm::enumerate(sizes))
945  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
946  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
947  unusedDims.set(dim.index());
948 
949  // Early exit for the case where the number of unused dims matches the number
950  // of ranks reduced.
951  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
952  originalType.getRank())
953  return unusedDims;
954 
955  SmallVector<int64_t> originalStrides, candidateStrides;
956  int64_t originalOffset, candidateOffset;
957  if (failed(
958  getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
959  failed(
960  getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
961  return failure();
962 
963  // For memrefs, a dimension is truly dropped if its corresponding stride is
964  // also dropped. This is particularly important when more than one of the dims
965  // is 1. Track the number of occurrences of the strides in the original type
966  // and the candidate type. For each unused dim, that stride should not be
967  // present in the candidate type. Note that there could be multiple dimensions
968  // that have the same size. We don't need to figure out exactly which dim
969  // corresponds to which stride; we just need to verify that the number of
970  // repetitions of a stride in the original + the number of unused dims with that
971  // stride == the number of repetitions of that stride in the candidate.
972  std::map<int64_t, unsigned> currUnaccountedStrides =
973  getNumOccurences(originalStrides);
974  std::map<int64_t, unsigned> candidateStridesNumOccurences =
975  getNumOccurences(candidateStrides);
976  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
977  if (!unusedDims.test(dim))
978  continue;
979  int64_t originalStride = originalStrides[dim];
980  if (currUnaccountedStrides[originalStride] >
981  candidateStridesNumOccurences[originalStride]) {
982  // This dim can be treated as dropped.
983  currUnaccountedStrides[originalStride]--;
984  continue;
985  }
986  if (currUnaccountedStrides[originalStride] ==
987  candidateStridesNumOccurences[originalStride]) {
988  // The stride for this is not dropped. Keep as is.
989  unusedDims.reset(dim);
990  continue;
991  }
992  if (currUnaccountedStrides[originalStride] <
993  candidateStridesNumOccurences[originalStride]) {
994  // This should never happen. We can't have a stride in the reduced-rank type
995  // that wasn't in the original one.
996  return failure();
997  }
998  }
999 
1000  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1001  originalType.getRank())
1002  return failure();
1003  return unusedDims;
1004 }
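 
// Worked example (illustrative) of the disambiguation above: reducing
// memref<1x4x1xf32> (strides [4, 1, 1]) to memref<4x1xf32> (strides [1, 1])
// with sizes [1, 4, 1], e.g. via
//
//   %0 = memref.subview %src[0, 0, 0] [1, 4, 1] [1, 1, 1]
//       : memref<1x4x1xf32> to memref<4x1xf32>
//
// initially marks dims 0 and 2 as candidates (both have size 1), but only the
// stride of dim 0 disappears from the candidate type, so the returned mask
// keeps just dim 0 as dropped.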
1005 
1006 llvm::SmallBitVector SubViewOp::getDroppedDims() {
1007  MemRefType sourceType = getSourceType();
1008  MemRefType resultType = getType();
1009  FailureOr<llvm::SmallBitVector> unusedDims =
1010  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1011  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1012  return *unusedDims;
1013 }
1014 
1015 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1016  // All forms of folding require a known index.
1017  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1018  if (!index)
1019  return {};
1020 
1021  // Folding for unranked types (UnrankedMemRefType) is not supported.
1022  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1023  if (!memrefType)
1024  return {};
1025 
1026  // Out of bound indices produce undefined behavior but are still valid IR.
1027  // Don't choke on them.
1028  int64_t indexVal = index.getInt();
1029  if (indexVal < 0 || indexVal >= memrefType.getRank())
1030  return {};
1031 
1032  // Fold if the shape extent along the given index is known.
1033  if (!memrefType.isDynamicDim(index.getInt())) {
1034  Builder builder(getContext());
1035  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1036  }
1037 
1038  // The size at the given index is now known to be a dynamic size.
1039  unsigned unsignedIndex = index.getValue().getZExtValue();
1040 
1041  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1042  Operation *definingOp = getSource().getDefiningOp();
1043 
1044  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1045  return *(alloc.getDynamicSizes().begin() +
1046  memrefType.getDynamicDimIndex(unsignedIndex));
1047 
1048  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1049  return *(alloca.getDynamicSizes().begin() +
1050  memrefType.getDynamicDimIndex(unsignedIndex));
1051 
1052  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1053  return *(view.getDynamicSizes().begin() +
1054  memrefType.getDynamicDimIndex(unsignedIndex));
1055 
1056  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1057  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1058  unsigned resultIndex = 0;
1059  unsigned sourceRank = subview.getSourceType().getRank();
1060  unsigned sourceIndex = 0;
1061  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1062  if (unusedDims.test(i))
1063  continue;
1064  if (resultIndex == unsignedIndex) {
1065  sourceIndex = i;
1066  break;
1067  }
1068  resultIndex++;
1069  }
1070  assert(subview.isDynamicSize(sourceIndex) &&
1071  "expected dynamic subview size");
1072  return subview.getDynamicSize(sourceIndex);
1073  }
1074 
1075  if (auto sizeInterface =
1076  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1077  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1078  "Expected dynamic subview size");
1079  return sizeInterface.getDynamicSize(unsignedIndex);
1080  }
1081 
1082  // dim(memrefcast) -> dim
1083  if (succeeded(foldMemRefCast(*this)))
1084  return getResult();
1085 
1086  return {};
1087 }
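 
// Example (illustrative, with hypothetical SSA names): with a known index, the
// dim folds to either a constant or the corresponding dynamic size operand:
//
//   %alloc = memref.alloc(%n) : memref<4x?xf32>
//   %c1 = arith.constant 1 : index
//   %d = memref.dim %alloc, %c1 : memref<4x?xf32>
//
// folds %d to %n, while an index of 0 would fold to the constant 4.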
1088 
1089 namespace {
1090 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1091 /// operand.
1092 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1093  using OpRewritePattern<DimOp>::OpRewritePattern;
1094 
1095  LogicalResult matchAndRewrite(DimOp dim,
1096  PatternRewriter &rewriter) const override {
1097  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1098 
1099  if (!reshape)
1100  return rewriter.notifyMatchFailure(
1101  dim, "Dim op is not defined by a reshape op.");
1102 
1103  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1104  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1105  // cheaply check that either of the following conditions hold:
1106  // 1. dim.getIndex() is defined in the same block as reshape but before
1107  // reshape.
1108  // 2. dim.getIndex() is defined in a parent block of
1109  // reshape.
1110 
1111  // Check condition 1
1112  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1113  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1114  if (reshape->isBeforeInBlock(definingOp)) {
1115  return rewriter.notifyMatchFailure(
1116  dim,
1117  "dim.getIndex is not defined before reshape in the same block.");
1118  }
1119  } // else dim.getIndex is a block argument to reshape->getBlock and
1120  // dominates reshape
1121  } // Check condition 2
1122  else if (dim->getBlock() != reshape->getBlock() &&
1123  !dim.getIndex().getParentRegion()->isProperAncestor(
1124  reshape->getParentRegion())) {
1125  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1126  // already know dim.getIndex() dominates reshape without calling
1127  // `isProperAncestor`
1128  return rewriter.notifyMatchFailure(
1129  dim, "dim.getIndex does not dominate reshape.");
1130  }
1131 
1132  // Place the load directly after the reshape to ensure that the shape memref
1133  // was not mutated.
1134  rewriter.setInsertionPointAfter(reshape);
1135  Location loc = dim.getLoc();
1136  Value load =
1137  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1138  if (load.getType() != dim.getType())
1139  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1140  rewriter.replaceOp(dim, load);
1141  return success();
1142  }
1143 };
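 
// Example (illustrative) of the rewrite performed by DimOfMemRefReshape:
//
//   %r = memref.reshape %src(%shape)
//       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %c0 : memref<?x?xf32>
//
// becomes a load of the shape operand:
//
//   %d = memref.load %shape[%c0] : memref<2xindex>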
1144 
1145 } // namespace
1146 
1147 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1148  MLIRContext *context) {
1149  results.add<DimOfMemRefReshape>(context);
1150 }
1151 
1152 // ---------------------------------------------------------------------------
1153 // DmaStartOp
1154 // ---------------------------------------------------------------------------
1155 
1156 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1157  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1158  ValueRange destIndices, Value numElements,
1159  Value tagMemRef, ValueRange tagIndices, Value stride,
1160  Value elementsPerStride) {
1161  result.addOperands(srcMemRef);
1162  result.addOperands(srcIndices);
1163  result.addOperands(destMemRef);
1164  result.addOperands(destIndices);
1165  result.addOperands({numElements, tagMemRef});
1166  result.addOperands(tagIndices);
1167  if (stride)
1168  result.addOperands({stride, elementsPerStride});
1169 }
1170 
1172  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1173  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1174  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1175  if (isStrided())
1176  p << ", " << getStride() << ", " << getNumElementsPerStride();
1177 
1178  p.printOptionalAttrDict((*this)->getAttrs());
1179  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1180  << ", " << getTagMemRef().getType();
1181 }
1182 
1183 // Parse DmaStartOp.
1184 // Ex:
1185 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1186 // %tag[%index], %stride, %num_elt_per_stride :
1187 // : memref<3076 x f32, 0>,
1188 // memref<1024 x f32, 2>,
1189 // memref<1 x i32>
1190 //
1191 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1192  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1193  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1194  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1195  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1196  OpAsmParser::UnresolvedOperand numElementsInfo;
1197  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1198  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1199  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1200 
1201  SmallVector<Type, 3> types;
1202  auto indexType = parser.getBuilder().getIndexType();
1203 
1204  // Parse and resolve the following list of operands:
1205  // *) source memref followed by its indices (in square brackets).
1206  // *) destination memref followed by its indices (in square brackets).
1207  // *) dma size in KiB.
1208  if (parser.parseOperand(srcMemRefInfo) ||
1209  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1210  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1211  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1212  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1213  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1214  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1215  return failure();
1216 
1217  // Parse optional stride and elements per stride.
1218  if (parser.parseTrailingOperandList(strideInfo))
1219  return failure();
1220 
1221  bool isStrided = strideInfo.size() == 2;
1222  if (!strideInfo.empty() && !isStrided) {
1223  return parser.emitError(parser.getNameLoc(),
1224  "expected two stride related operands");
1225  }
1226 
1227  if (parser.parseColonTypeList(types))
1228  return failure();
1229  if (types.size() != 3)
1230  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1231 
1232  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1233  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1234  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1235  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1236  // size should be an index.
1237  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1238  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1239  // tag indices should be index.
1240  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1241  return failure();
1242 
1243  if (isStrided) {
1244  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1245  return failure();
1246  }
1247 
1248  return success();
1249 }
1250 
1251 LogicalResult DmaStartOp::verify() {
1252  unsigned numOperands = getNumOperands();
1253 
1254  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1255  // the number of elements.
1256  if (numOperands < 4)
1257  return emitOpError("expected at least 4 operands");
1258 
1259  // Check types of operands. The order of these calls is important: the later
1260  // calls rely on some type properties to compute the operand position.
1261  // 1. Source memref.
1262  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1263  return emitOpError("expected source to be of memref type");
1264  if (numOperands < getSrcMemRefRank() + 4)
1265  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1266  << " operands";
1267  if (!getSrcIndices().empty() &&
1268  !llvm::all_of(getSrcIndices().getTypes(),
1269  [](Type t) { return t.isIndex(); }))
1270  return emitOpError("expected source indices to be of index type");
1271 
1272  // 2. Destination memref.
1273  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1274  return emitOpError("expected destination to be of memref type");
1275  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1276  if (numOperands < numExpectedOperands)
1277  return emitOpError() << "expected at least " << numExpectedOperands
1278  << " operands";
1279  if (!getDstIndices().empty() &&
1280  !llvm::all_of(getDstIndices().getTypes(),
1281  [](Type t) { return t.isIndex(); }))
1282  return emitOpError("expected destination indices to be of index type");
1283 
1284  // 3. Number of elements.
1285  if (!getNumElements().getType().isIndex())
1286  return emitOpError("expected num elements to be of index type");
1287 
1288  // 4. Tag memref.
1289  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1290  return emitOpError("expected tag to be of memref type");
1291  numExpectedOperands += getTagMemRefRank();
1292  if (numOperands < numExpectedOperands)
1293  return emitOpError() << "expected at least " << numExpectedOperands
1294  << " operands";
1295  if (!getTagIndices().empty() &&
1296  !llvm::all_of(getTagIndices().getTypes(),
1297  [](Type t) { return t.isIndex(); }))
1298  return emitOpError("expected tag indices to be of index type");
1299 
1300  // Optional stride-related operands must be either both present or both
1301  // absent.
1302  if (numOperands != numExpectedOperands &&
1303  numOperands != numExpectedOperands + 2)
1304  return emitOpError("incorrect number of operands");
1305 
1306  // 5. Strides.
1307  if (isStrided()) {
1308  if (!getStride().getType().isIndex() ||
1309  !getNumElementsPerStride().getType().isIndex())
1310  return emitOpError(
1311  "expected stride and num elements per stride to be of type index");
1312  }
1313 
1314  return success();
1315 }
1316 
1317 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1318  SmallVectorImpl<OpFoldResult> &results) {
1319  /// dma_start(memrefcast) -> dma_start
1320  return foldMemRefCast(*this);
1321 }
1322 
1323 // ---------------------------------------------------------------------------
1324 // DmaWaitOp
1325 // ---------------------------------------------------------------------------
1326 
1327 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1328  SmallVectorImpl<OpFoldResult> &results) {
1329  /// dma_wait(memrefcast) -> dma_wait
1330  return foldMemRefCast(*this);
1331 }
1332 
1333 LogicalResult DmaWaitOp::verify() {
1334  // Check that the number of tag indices matches the tagMemRef rank.
1335  unsigned numTagIndices = getTagIndices().size();
1336  unsigned tagMemRefRank = getTagMemRefRank();
1337  if (numTagIndices != tagMemRefRank)
1338  return emitOpError() << "expected tagIndices to have the same number of "
1339  "elements as the tagMemRef rank, expected "
1340  << tagMemRefRank << ", but got " << numTagIndices;
1341  return success();
1342 }
1343 
1344 //===----------------------------------------------------------------------===//
1345 // ExtractAlignedPointerAsIndexOp
1346 //===----------------------------------------------------------------------===//
1347 
1348 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1349  function_ref<void(Value, StringRef)> setNameFn) {
1350  setNameFn(getResult(), "intptr");
1351 }
1352 
1353 //===----------------------------------------------------------------------===//
1354 // ExtractStridedMetadataOp
1355 //===----------------------------------------------------------------------===//
1356 
1357 /// The number and type of the results are inferred from the
1358 /// shape of the source.
1359 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1360  MLIRContext *context, std::optional<Location> location,
1361  ExtractStridedMetadataOp::Adaptor adaptor,
1362  SmallVectorImpl<Type> &inferredReturnTypes) {
1363  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1364  if (!sourceType)
1365  return failure();
1366 
1367  unsigned sourceRank = sourceType.getRank();
1368  IndexType indexType = IndexType::get(context);
1369  auto memrefType =
1370  MemRefType::get({}, sourceType.getElementType(),
1371  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1372  // Base.
1373  inferredReturnTypes.push_back(memrefType);
1374  // Offset.
1375  inferredReturnTypes.push_back(indexType);
1376  // Sizes and strides.
1377  for (unsigned i = 0; i < sourceRank * 2; ++i)
1378  inferredReturnTypes.push_back(indexType);
1379  return success();
1380 }
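 
// Example (illustrative) of the inferred results for a 2-D source: a 0-d base
// memref, an offset, then one size and one stride per source dimension:
//
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %m
//       : memref<?x?xf32, strided<[?, 1], offset: ?>>
//       -> memref<f32>, index, index, index, index, index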
1381 
1382 void ExtractStridedMetadataOp::getAsmResultNames(
1383  function_ref<void(Value, StringRef)> setNameFn) {
1384  setNameFn(getBaseBuffer(), "base_buffer");
1385  setNameFn(getOffset(), "offset");
1386  // For multi-result to work properly with pretty names and packed syntax `x:3`
1387  // we can only give a pretty name to the first value in the pack.
1388  if (!getSizes().empty()) {
1389  setNameFn(getSizes().front(), "sizes");
1390  setNameFn(getStrides().front(), "strides");
1391  }
1392 }
1393 
1394 /// Helper function to perform the replacement of all constant uses of `values`
1395 /// by a materialized constant extracted from `maybeConstants`.
1396 /// `values` and `maybeConstants` are expected to have the same size.
1397 template <typename Container>
1398 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1399  Container values,
1400  ArrayRef<OpFoldResult> maybeConstants) {
1401  assert(values.size() == maybeConstants.size() &&
1402  " expected values and maybeConstants of the same size");
1403  bool atLeastOneReplacement = false;
1404  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1405  // Don't materialize a constant if there are no uses: this would induce
1406  // infinite loops in the driver.
1407  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1408  continue;
1409  assert(maybeConstant.template is<Attribute>() &&
1410  "The constified value should be either unchanged (i.e., == result) "
1411  "or a constant");
1412  Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1413  loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
1414  .getInt());
1415  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1416  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1417  // yet.
1418  op->replaceUsesOfWith(result, constantVal);
1419  atLeastOneReplacement = true;
1420  }
1421  }
1422  return atLeastOneReplacement;
1423 }
1424 
1425 LogicalResult
1426 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1427  SmallVectorImpl<OpFoldResult> &results) {
1428  OpBuilder builder(*this);
1429 
1430  bool atLeastOneReplacement = replaceConstantUsesOf(
1431  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1432  getConstifiedMixedOffset());
1433  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1434  getConstifiedMixedSizes());
1435  atLeastOneReplacement |= replaceConstantUsesOf(
1436  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1437 
1438  return success(atLeastOneReplacement);
1439 }
1440 
1441 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1442  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1443  constifyIndexValues(values, getSource().getType(), getContext(),
1444  getConstantSizes, ShapedType::isDynamic);
1445  return values;
1446 }
1447 
1448 SmallVector<OpFoldResult>
1449 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1450  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1451  constifyIndexValues(values, getSource().getType(), getContext(),
1452  getConstantStrides, ShapedType::isDynamic);
1453  return values;
1454 }
1455 
1456 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1457  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1458  SmallVector<OpFoldResult> values(1, offsetOfr);
1459  constifyIndexValues(values, getSource().getType(), getContext(),
1460  getConstantOffset, ShapedType::isDynamic);
1461  return values[0];
1462 }
1463 
1464 //===----------------------------------------------------------------------===//
1465 // GenericAtomicRMWOp
1466 //===----------------------------------------------------------------------===//
1467 
1468 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1469  Value memref, ValueRange ivs) {
1470  OpBuilder::InsertionGuard g(builder);
1471  result.addOperands(memref);
1472  result.addOperands(ivs);
1473 
1474  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1475  Type elementType = memrefType.getElementType();
1476  result.addTypes(elementType);
1477 
1478  Region *bodyRegion = result.addRegion();
1479  builder.createBlock(bodyRegion);
1480  bodyRegion->addArgument(elementType, memref.getLoc());
1481  }
1482 }
1483 
1484 LogicalResult GenericAtomicRMWOp::verify() {
1485  auto &body = getRegion();
1486  if (body.getNumArguments() != 1)
1487  return emitOpError("expected single number of entry block arguments");
1488 
1489  if (getResult().getType() != body.getArgument(0).getType())
1490  return emitOpError("expected block argument of the same type as result type");
1491 
1492  bool hasSideEffects =
1493  body.walk([&](Operation *nestedOp) {
1494  if (isMemoryEffectFree(nestedOp))
1495  return WalkResult::advance();
1496  nestedOp->emitError(
1497  "body of 'memref.generic_atomic_rmw' should contain "
1498  "only operations with no side effects");
1499  return WalkResult::interrupt();
1500  })
1501  .wasInterrupted();
1502  return hasSideEffects ? failure() : success();
1503 }
1504 
1505 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1506  OperationState &result) {
1507  OpAsmParser::UnresolvedOperand memref;
1508  Type memrefType;
1509  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1510 
1511  Type indexType = parser.getBuilder().getIndexType();
1512  if (parser.parseOperand(memref) ||
1513  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1514  parser.parseColonType(memrefType) ||
1515  parser.resolveOperand(memref, memrefType, result.operands) ||
1516  parser.resolveOperands(ivs, indexType, result.operands))
1517  return failure();
1518 
1519  Region *body = result.addRegion();
1520  if (parser.parseRegion(*body, {}) ||
1521  parser.parseOptionalAttrDict(result.attributes))
1522  return failure();
1523  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1524  return success();
1525 }
1526 
1527 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1528  p << ' ' << getMemref() << "[" << getIndices()
1529  << "] : " << getMemref().getType() << ' ';
1530  p.printRegion(getRegion());
1531  p.printOptionalAttrDict((*this)->getAttrs());
1532 }
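
A sketch of using the builder above and then filling in the atomic body. The values `b`, `loc`, `buf` (of type `memref<8xf32>`), and the index `i` are assumptions for illustration, not part of this file:

```
// Sketch only: assumes an OpBuilder `b`, a Location `loc`, a Value `buf` of
// type memref<8xf32>, and an index Value `i`.
auto rmw = b.create<memref::GenericAtomicRMWOp>(loc, buf, ValueRange{i});
{
  // The build method above already created a single-block body whose only
  // argument is the current element value at buf[i].
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointToEnd(&rmw.getRegion().front());
  Value current = rmw.getRegion().getArgument(0);
  Value updated = b.create<arith::AddFOp>(loc, current, current);
  b.create<memref::AtomicYieldOp>(loc, updated);
}
```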
1533 
1534 //===----------------------------------------------------------------------===//
1535 // AtomicYieldOp
1536 //===----------------------------------------------------------------------===//
1537 
1538 LogicalResult AtomicYieldOp::verify() {
1539  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1540  Type resultType = getResult().getType();
1541  if (parentType != resultType)
1542  return emitOpError() << "types mismatch between yield op: " << resultType
1543  << " and its parent: " << parentType;
1544  return success();
1545 }
1546 
1547 //===----------------------------------------------------------------------===//
1548 // GlobalOp
1549 //===----------------------------------------------------------------------===//
1550 
1551 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1552  TypeAttr type,
1553  Attribute initialValue) {
1554  p << type;
1555  if (!op.isExternal()) {
1556  p << " = ";
1557  if (op.isUninitialized())
1558  p << "uninitialized";
1559  else
1560  p.printAttributeWithoutType(initialValue);
1561  }
1562 }
1563 
1564 static ParseResult
1565 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1566  Attribute &initialValue) {
1567  Type type;
1568  if (parser.parseType(type))
1569  return failure();
1570 
1571  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1572  if (!memrefType || !memrefType.hasStaticShape())
1573  return parser.emitError(parser.getNameLoc())
1574  << "type should be static shaped memref, but got " << type;
1575  typeAttr = TypeAttr::get(type);
1576 
1577  if (parser.parseOptionalEqual())
1578  return success();
1579 
1580  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1581  initialValue = UnitAttr::get(parser.getContext());
1582  return success();
1583  }
1584 
1585  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1586  if (parser.parseAttribute(initialValue, tensorType))
1587  return failure();
1588  if (!llvm::isa<ElementsAttr>(initialValue))
1589  return parser.emitError(parser.getNameLoc())
1590  << "initial value should be a unit or elements attribute";
1591  return success();
1592 }
1593 
1594 LogicalResult GlobalOp::verify() {
1595  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1596  if (!memrefType || !memrefType.hasStaticShape())
1597  return emitOpError("type should be static shaped memref, but got ")
1598  << getType();
1599 
1600  // Verify that the initial value, if present, is either a unit attribute or
1601  // an elements attribute.
1602  if (getInitialValue().has_value()) {
1603  Attribute initValue = getInitialValue().value();
1604  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1605  return emitOpError("initial value should be a unit or elements "
1606  "attribute, but got ")
1607  << initValue;
1608 
1609  // Check that the type of the initial value is compatible with the type of
1610  // the global variable.
1611  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1612  Type initType = elementsAttr.getType();
1613  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1614  if (initType != tensorType)
1615  return emitOpError("initial value expected to be of type ")
1616  << tensorType << ", but was of type " << initType;
1617  }
1618  }
1619 
1620  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1621  uint64_t alignment = *alignAttr;
1622 
1623  if (!llvm::isPowerOf2_64(alignment))
1624  return emitError() << "alignment attribute value " << alignment
1625  << " is not a power of 2";
1626  }
1627 
1628  // TODO: verify visibility for declarations.
1629  return success();
1630 }
1631 
1632 ElementsAttr GlobalOp::getConstantInitValue() {
1633  auto initVal = getInitialValue();
1634  if (getConstant() && initVal.has_value())
1635  return llvm::cast<ElementsAttr>(initVal.value());
1636  return {};
1637 }
1638 
1639 //===----------------------------------------------------------------------===//
1640 // GetGlobalOp
1641 //===----------------------------------------------------------------------===//
1642 
1643 LogicalResult
1644 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1645  // Verify that the result type is same as the type of the referenced
1646  // memref.global op.
1647  auto global =
1648  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1649  if (!global)
1650  return emitOpError("'")
1651  << getName() << "' does not reference a valid global memref";
1652 
1653  Type resultType = getResult().getType();
1654  if (global.getType() != resultType)
1655  return emitOpError("result type ")
1656  << resultType << " does not match type " << global.getType()
1657  << " of the global memref @" << getName();
1658  return success();
1659 }
1660 
1661 //===----------------------------------------------------------------------===//
1662 // LoadOp
1663 //===----------------------------------------------------------------------===//
1664 
1665 LogicalResult LoadOp::verify() {
1666  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1667  return emitOpError("incorrect number of indices for load, expected ")
1668  << getMemRefType().getRank() << " but got " << getIndices().size();
1669  }
1670  return success();
1671 }
1672 
1673 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1674  /// load(memrefcast) -> load
1675  if (succeeded(foldMemRefCast(*this)))
1676  return getResult();
1677  return OpFoldResult();
1678 }
1679 
1680 //===----------------------------------------------------------------------===//
1681 // MemorySpaceCastOp
1682 //===----------------------------------------------------------------------===//
1683 
1684 void MemorySpaceCastOp::getAsmResultNames(
1685  function_ref<void(Value, StringRef)> setNameFn) {
1686  setNameFn(getResult(), "memspacecast");
1687 }
1688 
1689 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1690  if (inputs.size() != 1 || outputs.size() != 1)
1691  return false;
1692  Type a = inputs.front(), b = outputs.front();
1693  auto aT = llvm::dyn_cast<MemRefType>(a);
1694  auto bT = llvm::dyn_cast<MemRefType>(b);
1695 
1696  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1697  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1698 
1699  if (aT && bT) {
1700  if (aT.getElementType() != bT.getElementType())
1701  return false;
1702  if (aT.getLayout() != bT.getLayout())
1703  return false;
1704  if (aT.getShape() != bT.getShape())
1705  return false;
1706  return true;
1707  }
1708  if (uaT && ubT) {
1709  return uaT.getElementType() == ubT.getElementType();
1710  }
1711  return false;
1712 }
1713 
1714 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1715  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1716  // t2)
1717  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1718  getSourceMutable().assign(parentCast.getSource());
1719  return getResult();
1720  }
1721  return Value{};
1722 }
1723 
1724 //===----------------------------------------------------------------------===//
1725 // PrefetchOp
1726 //===----------------------------------------------------------------------===//
1727 
1728 void PrefetchOp::print(OpAsmPrinter &p) {
1729  p << " " << getMemref() << '[';
1730  p.printOperands(getIndices());
1731  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1732  p << ", locality<" << getLocalityHint();
1733  p << ">, " << (getIsDataCache() ? "data" : "instr");
1734  p.printOptionalAttrDict(
1735  (*this)->getAttrs(),
1736  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1737  p << " : " << getMemRefType();
1738 }
1739 
1740 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1741  OpAsmParser::UnresolvedOperand memrefInfo;
1742  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1743  IntegerAttr localityHint;
1744  MemRefType type;
1745  StringRef readOrWrite, cacheType;
1746 
1747  auto indexTy = parser.getBuilder().getIndexType();
1748  auto i32Type = parser.getBuilder().getIntegerType(32);
1749  if (parser.parseOperand(memrefInfo) ||
1750  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1751  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1752  parser.parseComma() || parser.parseKeyword("locality") ||
1753  parser.parseLess() ||
1754  parser.parseAttribute(localityHint, i32Type, "localityHint",
1755  result.attributes) ||
1756  parser.parseGreater() || parser.parseComma() ||
1757  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1758  parser.resolveOperand(memrefInfo, type, result.operands) ||
1759  parser.resolveOperands(indexInfo, indexTy, result.operands))
1760  return failure();
1761 
1762  if (readOrWrite != "read" && readOrWrite != "write")
1763  return parser.emitError(parser.getNameLoc(),
1764  "rw specifier has to be 'read' or 'write'");
1765  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1766  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1767 
1768  if (cacheType != "data" && cacheType != "instr")
1769  return parser.emitError(parser.getNameLoc(),
1770  "cache type has to be 'data' or 'instr'");
1771 
1772  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1773  parser.getBuilder().getBoolAttr(cacheType == "data"));
1774 
1775  return success();
1776 }
1777 
1778 LogicalResult PrefetchOp::verify() {
1779  if (getNumOperands() != 1 + getMemRefType().getRank())
1780  return emitOpError("too few indices");
1781 
1782  return success();
1783 }
1784 
1785 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1786  SmallVectorImpl<OpFoldResult> &results) {
1787  // prefetch(memrefcast) -> prefetch
1788  return foldMemRefCast(*this);
1789 }
1790 
1791 //===----------------------------------------------------------------------===//
1792 // RankOp
1793 //===----------------------------------------------------------------------===//
1794 
1795 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1796  // Constant fold rank when the rank of the operand is known.
1797  auto type = getOperand().getType();
1798  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1799  if (shapedType && shapedType.hasRank())
1800  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1801  return IntegerAttr();
1802 }
1803 
1804 //===----------------------------------------------------------------------===//
1805 // ReinterpretCastOp
1806 //===----------------------------------------------------------------------===//
1807 
1808 void ReinterpretCastOp::getAsmResultNames(
1809  function_ref<void(Value, StringRef)> setNameFn) {
1810  setNameFn(getResult(), "reinterpret_cast");
1811 }
1812 
1813 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1814 /// `staticSizes` and `staticStrides` are automatically filled with
1815 /// source-memref-rank sentinel values that encode dynamic entries.
1816 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1817  MemRefType resultType, Value source,
1818  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1819  ArrayRef<OpFoldResult> strides,
1820  ArrayRef<NamedAttribute> attrs) {
1821  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1822  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1823  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1824  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1825  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1826  result.addAttributes(attrs);
1827  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1828  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1829  b.getDenseI64ArrayAttr(staticSizes),
1830  b.getDenseI64ArrayAttr(staticStrides));
1831 }
1832 
1833 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1834  Value source, OpFoldResult offset,
1835  ArrayRef<OpFoldResult> sizes,
1836  ArrayRef<OpFoldResult> strides,
1837  ArrayRef<NamedAttribute> attrs) {
1838  auto sourceType = cast<BaseMemRefType>(source.getType());
1839  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1840  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1841  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1842  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1843  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1844  auto stridedLayout = StridedLayoutAttr::get(
1845  b.getContext(), staticOffsets.front(), staticStrides);
1846  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1847  stridedLayout, sourceType.getMemorySpace());
1848  build(b, result, resultType, source, offset, sizes, strides, attrs);
1849 }
1850 
1851 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1852  MemRefType resultType, Value source,
1853  int64_t offset, ArrayRef<int64_t> sizes,
1854  ArrayRef<int64_t> strides,
1855  ArrayRef<NamedAttribute> attrs) {
1856  SmallVector<OpFoldResult> sizeValues =
1857  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1858  return b.getI64IntegerAttr(v);
1859  }));
1860  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1861  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1862  return b.getI64IntegerAttr(v);
1863  }));
1864  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1865  strideValues, attrs);
1866 }
1867 
1868 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1869  MemRefType resultType, Value source, Value offset,
1870  ValueRange sizes, ValueRange strides,
1871  ArrayRef<NamedAttribute> attrs) {
1872  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1873  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1874  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1875  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1876  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1877 }
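
A sketch of the OpFoldResult-based builder above in action. The names `b`, `loc`, `src` (of type `memref<?xf32>`), and the dynamic size Value `size` are assumptions for illustration only:

```
// Sketch only: assumes an OpBuilder `b`, a Location `loc`, a Value `src` of
// type memref<?xf32>, and an index Value `size` holding its dynamic size.
auto resultType = MemRefType::get(
    {ShapedType::kDynamic}, b.getF32Type(),
    StridedLayoutAttr::get(b.getContext(), /*offset=*/0, /*strides=*/{1}));
// Mixed static/dynamic entries: the static offset 0 and stride 1 land in the
// static_offsets/static_strides attributes, while the dynamic size becomes an
// operand, exactly as dispatchIndexOpFoldResults does above.
Value reinterpreted = b.create<memref::ReinterpretCastOp>(
    loc, resultType, src, /*offset=*/b.getIndexAttr(0),
    /*sizes=*/ArrayRef<OpFoldResult>{size},
    /*strides=*/ArrayRef<OpFoldResult>{b.getIndexAttr(1)});
```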
1878 
1879 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1880 // completed automatically, like we have for subview and extract_slice.
1881 LogicalResult ReinterpretCastOp::verify() {
1882  // The source and result memrefs should be in the same memory space.
1883  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1884  auto resultType = llvm::cast<MemRefType>(getType());
1885  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1886  return emitError("different memory spaces specified for source type ")
1887  << srcType << " and result memref type " << resultType;
1888  if (srcType.getElementType() != resultType.getElementType())
1889  return emitError("different element types specified for source type ")
1890  << srcType << " and result memref type " << resultType;
1891 
1892  // Match sizes in result memref type and in static_sizes attribute.
1893  for (auto [idx, resultSize, expectedSize] :
1894  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1895  if (!ShapedType::isDynamic(resultSize) &&
1896  !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1897  return emitError("expected result type with size = ")
1898  << expectedSize << " instead of " << resultSize
1899  << " in dim = " << idx;
1900  }
1901 
1902  // Match offset and strides in the static_offsets and static_strides attributes. If
1903  // result memref type has no affine map specified, this will assume an
1904  // identity layout.
1905  int64_t resultOffset;
1906  SmallVector<int64_t, 4> resultStrides;
1907  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1908  return emitError("expected result type to have strided layout but found ")
1909  << resultType;
1910 
1911  // Match offset in result memref type and in static_offsets attribute.
1912  int64_t expectedOffset = getStaticOffsets().front();
1913  if (!ShapedType::isDynamic(resultOffset) &&
1914  !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
1915  return emitError("expected result type with offset = ")
1916  << expectedOffset << " instead of " << resultOffset;
1917 
1918  // Match strides in result memref type and in static_strides attribute.
1919  for (auto [idx, resultStride, expectedStride] :
1920  llvm::enumerate(resultStrides, getStaticStrides())) {
1921  if (!ShapedType::isDynamic(resultStride) &&
1922  !ShapedType::isDynamic(expectedStride) &&
1923  resultStride != expectedStride)
1924  return emitError("expected result type with stride = ")
1925  << expectedStride << " instead of " << resultStride
1926  << " in dim = " << idx;
1927  }
1928 
1929  return success();
1930 }
1931 
1932 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1933  Value src = getSource();
1934  auto getPrevSrc = [&]() -> Value {
1935  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1936  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1937  return prev.getSource();
1938 
1939  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1940  if (auto prev = src.getDefiningOp<CastOp>())
1941  return prev.getSource();
1942 
1943  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1944  // are 0.
1945  if (auto prev = src.getDefiningOp<SubViewOp>())
1946  if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1947  return isConstantIntValue(val, 0);
1948  }))
1949  return prev.getSource();
1950 
1951  return nullptr;
1952  };
1953 
1954  if (auto prevSrc = getPrevSrc()) {
1955  getSourceMutable().assign(prevSrc);
1956  return getResult();
1957  }
1958 
1959  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1960  if (!ShapedType::isDynamicShape(getType().getShape()) &&
1961  src.getType() == getType() && getStaticOffsets().front() == 0) {
1962  return src;
1963  }
1964 
1965  return nullptr;
1966 }
1967 
1968 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1969  SmallVector<OpFoldResult> values = getMixedSizes();
1970  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
1971  ShapedType::isDynamic);
1972  return values;
1973 }
1974 
1975 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1976  SmallVector<OpFoldResult> values = getMixedStrides();
1977  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
1978  ShapedType::isDynamic);
1979  return values;
1980 }
1981 
1982 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1983  SmallVector<OpFoldResult> values = getMixedOffsets();
1984  assert(values.size() == 1 &&
1985  "reinterpret_cast must have one and only one offset");
1986  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
1987  ShapedType::isDynamic);
1988  return values[0];
1989 }
1990 
1991 namespace {
1992 /// Replace the sequence:
1993 /// ```
1994 /// base, offset, sizes, strides = extract_strided_metadata src
1995 /// dst = reinterpret_cast base to offset, sizes, strides
1996 /// ```
1997 /// With
1998 ///
1999 /// ```
2000 /// dst = memref.cast src
2001 /// ```
2002 ///
2003 /// Note: The cast operation is only inserted when the type of dst and src
2004 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
2005 ///
2006 /// This pattern also matches when the offset, sizes, and strides don't come
2007 /// directly from the `extract_strided_metadata`'s results but it can be
2008 /// statically proven that they would hold the same values.
2009 ///
2010 /// For instance, the following sequence would be replaced:
2011 /// ```
2012 /// base, offset, sizes, strides =
2013 /// extract_strided_metadata memref : memref<3x4xty>
2014 /// dst = reinterpret_cast base to 0, [3, 4], strides
2015 /// ```
2016 /// Because we know (thanks to the type of the input memref) that variables
2017 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
2018 ///
2019 /// Similarly, the following sequence would be replaced:
2020 /// ```
2021 /// c0 = arith.constant 0
2022 /// c4 = arith.constant 4
2023 /// base, offset, sizes, strides =
2024 /// extract_strided_metadata memref : memref<3x4xty>
2025 /// dst = reinterpret_cast base to c0, [3, c4], strides
2026 /// ```
2027 /// Because we know that `offset` and `c0` will hold 0
2028 /// and `c4` will hold 4.
2029 struct ReinterpretCastOpExtractStridedMetadataFolder
2030  : public OpRewritePattern<ReinterpretCastOp> {
2031 public:
2032  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2033 
2034  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2035  PatternRewriter &rewriter) const override {
2036  auto extractStridedMetadata =
2037  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2038  if (!extractStridedMetadata)
2039  return failure();
2040  // Check if the reinterpret cast reconstructs a memref with the exact same
2041  // properties as the extract strided metadata.
2042 
2043  // First, check that the strides are the same.
2044  SmallVector<OpFoldResult> extractStridesOfr =
2045  extractStridedMetadata.getConstifiedMixedStrides();
2046  SmallVector<OpFoldResult> reinterpretStridesOfr =
2047  op.getConstifiedMixedStrides();
2048  if (extractStridesOfr.size() != reinterpretStridesOfr.size())
2049  return failure();
2050 
2051  unsigned rank = op.getType().getRank();
2052  for (unsigned i = 0; i < rank; ++i) {
2053  if (extractStridesOfr[i] != reinterpretStridesOfr[i])
2054  return failure();
2055  }
2056 
2057  // Second, check the sizes.
2058  assert(extractStridedMetadata.getSizes().size() ==
2059  op.getMixedSizes().size() &&
2060  "Strides and sizes rank must match");
2061  SmallVector<OpFoldResult> extractSizesOfr =
2062  extractStridedMetadata.getConstifiedMixedSizes();
2063  SmallVector<OpFoldResult> reinterpretSizesOfr =
2064  op.getConstifiedMixedSizes();
2065  for (unsigned i = 0; i < rank; ++i) {
2066  if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2067  return failure();
2068  }
2069  // Finally, check the offset.
2070  assert(op.getMixedOffsets().size() == 1 &&
2071  "reinterpret_cast with more than one offset should have been "
2072  "rejected by the verifier");
2073  OpFoldResult extractOffsetOfr =
2074  extractStridedMetadata.getConstifiedMixedOffset();
2075  OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2076  if (extractOffsetOfr != reinterpretOffsetOfr)
2077  return failure();
2078 
2079  // At this point, we know that the back and forth between extract strided
2080  // metadata and reinterpret cast is a noop. However, the final type of the
2081  // reinterpret cast may not be exactly the same as the original memref.
2082  // E.g., it could be changing a dimension from static to dynamic. Check that
2083  // here and add a cast if necessary.
2084  Type srcTy = extractStridedMetadata.getSource().getType();
2085  if (srcTy == op.getResult().getType())
2086  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2087  else
2088  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2089  extractStridedMetadata.getSource());
2090 
2091  return success();
2092  }
2093 };
2094 } // namespace
2095 
2096 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2097  MLIRContext *context) {
2098  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2099 }
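
For completeness, a sketch of driving these canonicalization patterns from C++; it assumes a ModuleOp `module` and that `mlir/Transforms/GreedyPatternRewriteDriver.h` is included (normally the `canonicalize` pass does this for you):

```
// Sketch only: assumes a ModuleOp `module`.
MLIRContext *ctx = module.getContext();
RewritePatternSet patterns(ctx);
memref::ReinterpretCastOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
  module.emitError("pattern application did not converge");
```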
2100 
2101 //===----------------------------------------------------------------------===//
2102 // Reassociative reshape ops
2103 //===----------------------------------------------------------------------===//
2104 
2105 void CollapseShapeOp::getAsmResultNames(
2106  function_ref<void(Value, StringRef)> setNameFn) {
2107  setNameFn(getResult(), "collapse_shape");
2108 }
2109 
2110 void ExpandShapeOp::getAsmResultNames(
2111  function_ref<void(Value, StringRef)> setNameFn) {
2112  setNameFn(getResult(), "expand_shape");
2113 }
2114 
2115 LogicalResult ExpandShapeOp::reifyResultShapes(
2116  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2117  reifiedResultShapes = {
2118  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2119  return success();
2120 }
2121 
2122 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2123 /// result and operand. Layout maps are verified separately.
2124 ///
2125 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2126 /// allowed in a reassociation group.
2127 static LogicalResult
2128 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2129  ArrayRef<int64_t> expandedShape,
2130  ArrayRef<ReassociationIndices> reassociation,
2131  bool allowMultipleDynamicDimsPerGroup) {
2132  // There must be one reassociation group per collapsed dimension.
2133  if (collapsedShape.size() != reassociation.size())
2134  return op->emitOpError("invalid number of reassociation groups: found ")
2135  << reassociation.size() << ", expected " << collapsedShape.size();
2136 
2137  // The next expected expanded dimension index (while iterating over
2138  // reassociation indices).
2139  int64_t nextDim = 0;
2140  for (const auto &it : llvm::enumerate(reassociation)) {
2141  ReassociationIndices group = it.value();
2142  int64_t collapsedDim = it.index();
2143 
2144  bool foundDynamic = false;
2145  for (int64_t expandedDim : group) {
2146  if (expandedDim != nextDim++)
2147  return op->emitOpError("reassociation indices must be contiguous");
2148 
2149  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2150  return op->emitOpError("reassociation index ")
2151  << expandedDim << " is out of bounds";
2152 
2153  // Check if there are multiple dynamic dims in a reassociation group.
2154  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2155  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2156  return op->emitOpError(
2157  "at most one dimension in a reassociation group may be dynamic");
2158  foundDynamic = true;
2159  }
2160  }
2161 
2162  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2163  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2164  return op->emitOpError("collapsed dim (")
2165  << collapsedDim
2166  << ") must be dynamic if and only if reassociation group is "
2167  "dynamic";
2168 
2169  // If all dims in the reassociation group are static, the size of the
2170  // collapsed dim can be verified.
2171  if (!foundDynamic) {
2172  int64_t groupSize = 1;
2173  for (int64_t expandedDim : group)
2174  groupSize *= expandedShape[expandedDim];
2175  if (groupSize != collapsedShape[collapsedDim])
2176  return op->emitOpError("collapsed dim size (")
2177  << collapsedShape[collapsedDim]
2178  << ") must equal reassociation group size (" << groupSize << ")";
2179  }
2180  }
2181 
2182  if (collapsedShape.empty()) {
2183  // Rank 0: All expanded dimensions must be 1.
2184  for (int64_t d : expandedShape)
2185  if (d != 1)
2186  return op->emitOpError(
2187  "rank 0 memrefs can only be extended/collapsed with/from ones");
2188  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2189  // Rank >= 1: Number of dimensions among all reassociation groups must match
2190  // the result memref rank.
2191  return op->emitOpError("expanded rank (")
2192  << expandedShape.size()
2193  << ") inconsistent with number of reassociation indices (" << nextDim
2194  << ")";
2195  }
2196 
2197  return success();
2198 }
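
For a concrete sense of the group-size check above, here is a standalone recomputation in plain C++; the shapes are made up for illustration and there are no MLIR dependencies:

```
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Expanding memref<2x12xty> to memref<2x3x4xty> with reassociation
  // [[0], [1, 2]]: every fully static group must multiply out to the
  // corresponding collapsed dimension.
  std::vector<int64_t> collapsedShape = {2, 12};
  std::vector<int64_t> expandedShape = {2, 3, 4};
  std::vector<std::vector<int64_t>> reassociation = {{0}, {1, 2}};

  for (size_t g = 0; g < reassociation.size(); ++g) {
    int64_t groupSize = 1;
    for (int64_t expandedDim : reassociation[g])
      groupSize *= expandedShape[expandedDim];
    assert(groupSize == collapsedShape[g] &&
           "collapsed dim size must equal reassociation group size");
  }
  return 0;
}
```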
2199 
2200 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2201  return getSymbolLessAffineMaps(getReassociationExprs());
2202 }
2203 
2204 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2205  return convertReassociationIndicesToExprs(getContext(),
2206  getReassociationIndices());
2207 }
2208 
2209 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2210  return getSymbolLessAffineMaps(getReassociationExprs());
2211 }
2212 
2213 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2214  return convertReassociationIndicesToExprs(getContext(),
2215  getReassociationIndices());
2216 }
2217 
2218 /// Compute the layout map after expanding a given source MemRef type with the
2219 /// specified reassociation indices.
2220 static FailureOr<StridedLayoutAttr>
2221 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2222  ArrayRef<ReassociationIndices> reassociation) {
2223  int64_t srcOffset;
2224  SmallVector<int64_t> srcStrides;
2225  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2226  return failure();
2227  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2228 
2229  // 1-1 mapping between srcStrides and reassociation packs.
2230  // Each srcStride starts with the given value and gets expanded according to
2231  // the proper entries in resultShape.
2232  // Example:
2233  // srcStrides = [10000, 1 , 100 ],
2234  // reassociations = [ [0], [1], [2, 3, 4]],
2235  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2236  // -> For the purpose of stride calculation, the useful sizes are:
2237  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2238  // resultStrides = [10000, 1, 600, 200, 100]
2239  // Note that a stride does not get expanded along the first entry of each
2240  // shape pack.
2241  SmallVector<int64_t> reverseResultStrides;
2242  reverseResultStrides.reserve(resultShape.size());
2243  unsigned shapeIndex = resultShape.size() - 1;
2244  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2245  ReassociationIndices reassoc = std::get<0>(it);
2246  int64_t currentStrideToExpand = std::get<1>(it);
2247  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2248  reverseResultStrides.push_back(currentStrideToExpand);
2249  currentStrideToExpand =
2250  (SaturatedInteger::wrap(currentStrideToExpand) *
2251  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2252  .asInteger();
2253  }
2254  }
2255  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2256  resultStrides.resize(resultShape.size(), 1);
2257  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2258 }
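
The worked example in the comment above can be reproduced with a few lines of standalone C++; dynamic sizes and the SaturatedInteger wrapper are left out of this sketch:

```
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // srcStrides [10000, 1, 100], reassociation [[0], [1], [2, 3, 4]],
  // result sizes [2, 5, 4, 3, 2] -> result strides [10000, 1, 600, 200, 100].
  std::vector<int64_t> srcStrides = {10000, 1, 100};
  std::vector<std::vector<int64_t>> reassociation = {{0}, {1}, {2, 3, 4}};
  std::vector<int64_t> resultShape = {2, 5, 4, 3, 2};

  std::vector<int64_t> reverseResultStrides;
  int shapeIndex = static_cast<int>(resultShape.size()) - 1;
  for (int g = static_cast<int>(reassociation.size()) - 1; g >= 0; --g) {
    int64_t stride = srcStrides[g];
    for (size_t k = 0; k < reassociation[g].size(); ++k) {
      reverseResultStrides.push_back(stride);
      // The product computed after the last entry of a group is unused.
      stride *= resultShape[shapeIndex--];
    }
  }
  std::vector<int64_t> resultStrides(reverseResultStrides.rbegin(),
                                     reverseResultStrides.rend());
  assert((resultStrides == std::vector<int64_t>{10000, 1, 600, 200, 100}));
  return 0;
}
```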
2259 
2260 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2261  MemRefType srcType, ArrayRef<int64_t> resultShape,
2262  ArrayRef<ReassociationIndices> reassociation) {
2263  if (srcType.getLayout().isIdentity()) {
2264  // If the source is contiguous (i.e., no layout map specified), so is the
2265  // result.
2266  MemRefLayoutAttrInterface layout;
2267  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2268  srcType.getMemorySpace());
2269  }
2270 
2271  // Source may not be contiguous. Compute the layout map.
2272  FailureOr<StridedLayoutAttr> computedLayout =
2273  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2274  if (failed(computedLayout))
2275  return failure();
2276  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2277  srcType.getMemorySpace());
2278 }
2279 
2280 FailureOr<SmallVector<OpFoldResult>>
2281 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2282  MemRefType expandedType,
2283  ArrayRef<ReassociationIndices> reassociation,
2284  ArrayRef<OpFoldResult> inputShape) {
2285  std::optional<SmallVector<OpFoldResult>> outputShape =
2286  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2287  inputShape);
2288  if (!outputShape)
2289  return failure();
2290  return *outputShape;
2291 }
2292 
2293 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2294  Type resultType, Value src,
2295  ArrayRef<ReassociationIndices> reassociation,
2296  ArrayRef<OpFoldResult> outputShape) {
2297  auto [staticOutputShape, dynamicOutputShape] =
2298  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2299  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2300  getReassociationIndicesAttribute(builder, reassociation),
2301  dynamicOutputShape, staticOutputShape);
2302 }
2303 
2304 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2305  Type resultType, Value src,
2306  ArrayRef<ReassociationIndices> reassociation) {
2307  SmallVector<OpFoldResult> inputShape =
2308  getMixedSizes(builder, result.location, src);
2309  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2310  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2311  builder, result.location, memrefResultTy, reassociation, inputShape);
2312  // Failure of this assertion usually indicates presence of multiple
2313  // dynamic dimensions in the same reassociation group.
2314  assert(succeeded(outputShape) && "unable to infer output shape");
2315  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2316 }
2317 
2318 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2319  ArrayRef<int64_t> resultShape, Value src,
2320  ArrayRef<ReassociationIndices> reassociation) {
2321  // Only ranked memref source values are supported.
2322  auto srcType = llvm::cast<MemRefType>(src.getType());
2323  FailureOr<MemRefType> resultType =
2324  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2325  // Failure of this assertion usually indicates a problem with the source
2326  // type, e.g., could not get strides/offset.
2327  assert(succeeded(resultType) && "could not compute layout");
2328  build(builder, result, *resultType, src, reassociation);
2329 }
2330 
2331 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2332  ArrayRef<int64_t> resultShape, Value src,
2333  ArrayRef<ReassociationIndices> reassociation,
2334  ArrayRef<OpFoldResult> outputShape) {
2335  // Only ranked memref source values are supported.
2336  auto srcType = llvm::cast<MemRefType>(src.getType());
2337  FailureOr<MemRefType> resultType =
2338  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2339  // Failure of this assertion usually indicates a problem with the source
2340  // type, e.g., could not get strides/offset.
2341  assert(succeeded(resultType) && "could not compute layout");
2342  build(builder, result, *resultType, src, reassociation, outputShape);
2343 }
2344 
2345 LogicalResult ExpandShapeOp::verify() {
2346  MemRefType srcType = getSrcType();
2347  MemRefType resultType = getResultType();
2348 
2349  if (srcType.getRank() > resultType.getRank()) {
2350  auto r0 = srcType.getRank();
2351  auto r1 = resultType.getRank();
2352  return emitOpError("has source rank ")
2353  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2354  << r0 << " > " << r1 << ").";
2355  }
2356 
2357  // Verify result shape.
2358  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2359  resultType.getShape(),
2360  getReassociationIndices(),
2361  /*allowMultipleDynamicDimsPerGroup=*/true)))
2362  return failure();
2363 
2364  // Compute expected result type (including layout map).
2365  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2366  srcType, resultType.getShape(), getReassociationIndices());
2367  if (failed(expectedResultType))
2368  return emitOpError("invalid source layout map");
2369 
2370  // Check actual result type.
2371  if (*expectedResultType != resultType)
2372  return emitOpError("expected expanded type to be ")
2373  << *expectedResultType << " but found " << resultType;
2374 
2375  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2376  return emitOpError("expected number of static shape bounds to be equal to "
2377  "the output rank (")
2378  << resultType.getRank() << ") but found "
2379  << getStaticOutputShape().size() << " inputs instead";
2380 
2381  if ((int64_t)getOutputShape().size() !=
2382  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2383  return emitOpError("mismatch in dynamic dims in output_shape and "
2384  "static_output_shape: static_output_shape has ")
2385  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2386  << " dynamic dims while output_shape has " << getOutputShape().size()
2387  << " values";
2388 
2389  // Verify if provided output shapes are in agreement with output type.
2390  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2391  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2392  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2393  if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
2394  return emitOpError("invalid output shape provided at pos ") << pos;
2395  }
2396  }
2397 
2398  return success();
2399 }
2400 
2401 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2402  MLIRContext *context) {
2403  results.add<
2404  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2405  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2406 }
2407 
2408 /// Compute the layout map after collapsing a given source MemRef type with the
2409 /// specified reassociation indices.
2410 ///
2411 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2412 /// not possible to check this by inspecting a MemRefType in the general case.
2413 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2414 /// be valid (and thus accepted by this function) unless `strict = true`.
2415 static FailureOr<StridedLayoutAttr>
2416 computeCollapsedLayoutMap(MemRefType srcType,
2417  ArrayRef<ReassociationIndices> reassociation,
2418  bool strict = false) {
2419  int64_t srcOffset;
2420  SmallVector<int64_t> srcStrides;
2421  auto srcShape = srcType.getShape();
2422  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2423  return failure();
2424 
2425  // The result stride of a reassociation group is the stride of the last entry
2426  // of the reassociation. (TODO: Should be the minimum stride in the
2427  // reassociation because strides are not necessarily sorted. E.g., when using
2428  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2429  // strides are meaningless and could have any arbitrary value.
2430  SmallVector<int64_t> resultStrides;
2431  resultStrides.reserve(reassociation.size());
2432  for (const ReassociationIndices &reassoc : reassociation) {
2433  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2434  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2435  ref = ref.drop_back();
2436  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2437  resultStrides.push_back(srcStrides[ref.back()]);
2438  } else {
2439  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2440  // the corresponding stride may have to be skipped. (See above comment.)
2441  // Therefore, the result stride cannot be statically determined and must
2442  // be dynamic.
2443  resultStrides.push_back(ShapedType::kDynamic);
2444  }
2445  }
2446 
2447  // Validate that each reassociation group is contiguous.
2448  unsigned resultStrideIndex = resultStrides.size() - 1;
2449  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2450  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2451  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2452  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2453  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2454 
2455  // Both source and result stride must have the same static value. In that
2456  // case, we can be sure that the dimensions are collapsible (because they
2457  // are contiguous).
2458  // If `strict = false` (default during op verification), we accept cases
2459  // where one or both strides are dynamic. This is best effort: We reject
2460  // ops where obviously non-contiguous dims are collapsed, but accept ops
2461  // where we cannot be sure statically. Such ops may fail at runtime. See
2462  // the op documentation for details.
2463  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2464  if (strict && (stride.saturated || srcStride.saturated))
2465  return failure();
2466 
2467  // Dimensions of size 1 should be skipped, because their strides are
2468  // meaningless and could have any arbitrary value.
2469  if (srcShape[idx - 1] == 1)
2470  continue;
2471 
2472  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2473  return failure();
2474  }
2475  }
2476  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2477 }
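
The contiguity rule enforced above boils down to a stride/size product check. A standalone sketch with a fully static example (plain C++, no MLIR dependencies):

```
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // memref<2x4x5xf32> with the canonical strides [20, 5, 1]; collapsing
  // all three dims into one is valid because the layout is contiguous:
  // each dim's stride equals the next dim's stride times its size.
  std::vector<int64_t> shape = {2, 4, 5};
  std::vector<int64_t> strides = {20, 5, 1};
  bool contiguous = true;
  for (size_t d = 0; d + 1 < shape.size(); ++d)
    contiguous &= strides[d] == strides[d + 1] * shape[d + 1];
  assert(contiguous);
  return 0;
}
```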
2478 
2479 bool CollapseShapeOp::isGuaranteedCollapsible(
2480  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2481  // MemRefs with identity layout are always collapsible.
2482  if (srcType.getLayout().isIdentity())
2483  return true;
2484 
2485  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2486  /*strict=*/true));
2487 }
2488 
2489 MemRefType CollapseShapeOp::computeCollapsedType(
2490  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2491  SmallVector<int64_t> resultShape;
2492  resultShape.reserve(reassociation.size());
2493  for (const ReassociationIndices &group : reassociation) {
2494  auto groupSize = SaturatedInteger::wrap(1);
2495  for (int64_t srcDim : group)
2496  groupSize =
2497  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2498  resultShape.push_back(groupSize.asInteger());
2499  }
2500 
2501  if (srcType.getLayout().isIdentity()) {
2502  // If the source is contiguous (i.e., no layout map specified), so is the
2503  // result.
2504  MemRefLayoutAttrInterface layout;
2505  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2506  srcType.getMemorySpace());
2507  }
2508 
2509  // Source may not be fully contiguous. Compute the layout map.
2510  // Note: Dimensions that are collapsed into a single dim are assumed to be
2511  // contiguous.
2512  FailureOr<StridedLayoutAttr> computedLayout =
2513  computeCollapsedLayoutMap(srcType, reassociation);
2514  assert(succeeded(computedLayout) &&
2515  "invalid source layout map or collapsing non-contiguous dims");
2516  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2517  srcType.getMemorySpace());
2518 }
2519 
2520 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2521  ArrayRef<ReassociationIndices> reassociation,
2522  ArrayRef<NamedAttribute> attrs) {
2523  auto srcType = llvm::cast<MemRefType>(src.getType());
2524  MemRefType resultType =
2525  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2526  result.addAttribute(getReassociationAttrStrName(),
2527  getReassociationIndicesAttribute(b, reassociation));
2528  build(b, result, resultType, src, attrs);
2529 }
2530 
2531 LogicalResult CollapseShapeOp::verify() {
2532  MemRefType srcType = getSrcType();
2533  MemRefType resultType = getResultType();
2534 
2535  if (srcType.getRank() < resultType.getRank()) {
2536  auto r0 = srcType.getRank();
2537  auto r1 = resultType.getRank();
2538  return emitOpError("has source rank ")
2539  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2540  << r0 << " < " << r1 << ").";
2541  }
2542 
2543  // Verify result shape.
2544  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2545  srcType.getShape(), getReassociationIndices(),
2546  /*allowMultipleDynamicDimsPerGroup=*/true)))
2547  return failure();
2548 
2549  // Compute expected result type (including layout map).
2550  MemRefType expectedResultType;
2551  if (srcType.getLayout().isIdentity()) {
2552  // If the source is contiguous (i.e., no layout map specified), so is the
2553  // result.
2554  MemRefLayoutAttrInterface layout;
2555  expectedResultType =
2556  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2557  srcType.getMemorySpace());
2558  } else {
2559  // Source may not be fully contiguous. Compute the layout map.
2560  // Note: Dimensions that are collapsed into a single dim are assumed to be
2561  // contiguous.
2562  FailureOr<StridedLayoutAttr> computedLayout =
2563  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2564  if (failed(computedLayout))
2565  return emitOpError(
2566  "invalid source layout map or collapsing non-contiguous dims");
2567  expectedResultType =
2568  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2569  *computedLayout, srcType.getMemorySpace());
2570  }
2571 
2572  if (expectedResultType != resultType)
2573  return emitOpError("expected collapsed type to be ")
2574  << expectedResultType << " but found " << resultType;
2575 
2576  return success();
2577 }
2578 
2579 struct CollapseShapeOpMemRefCastFolder
2580  : public OpRewritePattern<CollapseShapeOp> {
2581 public:
2582  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2583 
2584  LogicalResult matchAndRewrite(CollapseShapeOp op,
2585  PatternRewriter &rewriter) const override {
2586  auto cast = op.getOperand().getDefiningOp<CastOp>();
2587  if (!cast)
2588  return failure();
2589 
2590  if (!CastOp::canFoldIntoConsumerOp(cast))
2591  return failure();
2592 
2593  Type newResultType = CollapseShapeOp::computeCollapsedType(
2594  llvm::cast<MemRefType>(cast.getOperand().getType()),
2595  op.getReassociationIndices());
2596 
2597  if (newResultType == op.getResultType()) {
2598  rewriter.modifyOpInPlace(
2599  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2600  } else {
2601  Value newOp = rewriter.create<CollapseShapeOp>(
2602  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2603  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2604  }
2605  return success();
2606  }
2607 };
2608 
2609 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2610  MLIRContext *context) {
2611  results.add<
2612  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2613  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2614  memref::DimOp, MemRefType>,
2615  CollapseShapeOpMemRefCastFolder>(context);
2616 }
2617 
2618 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2619  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2620  adaptor.getOperands());
2621 }
2622 
2623 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2624  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2625  adaptor.getOperands());
2626 }
2627 
2628 //===----------------------------------------------------------------------===//
2629 // ReshapeOp
2630 //===----------------------------------------------------------------------===//
2631 
2632 void ReshapeOp::getAsmResultNames(
2633  function_ref<void(Value, StringRef)> setNameFn) {
2634  setNameFn(getResult(), "reshape");
2635 }
2636 
2637 LogicalResult ReshapeOp::verify() {
2638  Type operandType = getSource().getType();
2639  Type resultType = getResult().getType();
2640 
2641  Type operandElementType =
2642  llvm::cast<ShapedType>(operandType).getElementType();
2643  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2644  if (operandElementType != resultElementType)
2645  return emitOpError("element types of source and destination memref "
2646  "types should be the same");
2647 
2648  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2649  if (!operandMemRefType.getLayout().isIdentity())
2650  return emitOpError("source memref type should have identity affine map");
2651 
2652  int64_t shapeSize =
2653  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2654  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2655  if (resultMemRefType) {
2656  if (!resultMemRefType.getLayout().isIdentity())
2657  return emitOpError("result memref type should have identity affine map");
2658  if (shapeSize == ShapedType::kDynamic)
2659  return emitOpError("cannot use shape operand with dynamic length to "
2660  "reshape to statically-ranked memref type");
2661  if (shapeSize != resultMemRefType.getRank())
2662  return emitOpError(
2663  "length of shape operand differs from the result's memref rank");
2664  }
2665  return success();
2666 }
2667 
2668 //===----------------------------------------------------------------------===//
2669 // StoreOp
2670 //===----------------------------------------------------------------------===//
2671 
2672 LogicalResult StoreOp::verify() {
2673  if (getNumOperands() != 2 + getMemRefType().getRank())
2674  return emitOpError("store index operand count not equal to memref rank");
2675 
2676  return success();
2677 }
2678 
2679 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2680  SmallVectorImpl<OpFoldResult> &results) {
2681  /// store(memrefcast) -> store
2682  return foldMemRefCast(*this, getValueToStore());
2683 }
2684 
2685 //===----------------------------------------------------------------------===//
2686 // SubViewOp
2687 //===----------------------------------------------------------------------===//
2688 
2689 void SubViewOp::getAsmResultNames(
2690  function_ref<void(Value, StringRef)> setNameFn) {
2691  setNameFn(getResult(), "subview");
2692 }
2693 
2694 /// A subview result type can be fully inferred from the source type and the
2695 /// static representation of offsets, sizes and strides. Special sentinels
2696 /// encode the dynamic case.
2697 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2698  ArrayRef<int64_t> staticOffsets,
2699  ArrayRef<int64_t> staticSizes,
2700  ArrayRef<int64_t> staticStrides) {
2701  unsigned rank = sourceMemRefType.getRank();
2702  (void)rank;
2703  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2704  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2705  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2706 
2707  // Extract source offset and strides.
2708  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);
2709 
2710  // Compute target offset whose value is:
2711  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2712  int64_t targetOffset = sourceOffset;
2713  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2714  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2715  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2716  SaturatedInteger::wrap(staticOffset) *
2717  SaturatedInteger::wrap(sourceStride))
2718  .asInteger();
2719  }
2720 
2721  // Compute target stride whose value is:
2722  // `sourceStrides_i * staticStrides_i`.
2723  SmallVector<int64_t, 4> targetStrides;
2724  targetStrides.reserve(staticOffsets.size());
2725  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2726  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2727  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2728  SaturatedInteger::wrap(staticStride))
2729  .asInteger());
2730  }
2731 
2732  // The type is now known.
2733  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2734  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2735  targetOffset, targetStrides),
2736  sourceMemRefType.getMemorySpace());
2737 }
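
The offset and stride arithmetic above is easy to check by hand. A standalone sketch with concrete numbers for a source memref<8x16xf32> with canonical strides [16, 1]:

```
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Subview with offsets [2, 4], sizes [4, 4], strides [1, 2] of a source
  // with strides [16, 1] and offset 0 -> layout strided<[16, 2], offset: 36>.
  int64_t sourceOffset = 0;
  std::vector<int64_t> sourceStrides = {16, 1};
  std::vector<int64_t> offsets = {2, 4};
  std::vector<int64_t> strides = {1, 2};

  int64_t targetOffset = sourceOffset;
  std::vector<int64_t> targetStrides;
  for (size_t d = 0; d < sourceStrides.size(); ++d) {
    targetOffset += offsets[d] * sourceStrides[d];
    targetStrides.push_back(sourceStrides[d] * strides[d]);
  }
  assert(targetOffset == 36);
  assert((targetStrides == std::vector<int64_t>{16, 2}));
  return 0;
}
```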
2738 
2739 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2740  ArrayRef<OpFoldResult> offsets,
2741  ArrayRef<OpFoldResult> sizes,
2742  ArrayRef<OpFoldResult> strides) {
2743  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2744  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2745  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2746  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2747  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2748  if (!hasValidSizesOffsets(staticOffsets))
2749  return {};
2750  if (!hasValidSizesOffsets(staticSizes))
2751  return {};
2752  if (!hasValidStrides(staticStrides))
2753  return {};
2754  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2755  staticSizes, staticStrides);
2756 }
2757 
2758 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2759  MemRefType sourceRankedTensorType,
2760  ArrayRef<int64_t> offsets,
2761  ArrayRef<int64_t> sizes,
2762  ArrayRef<int64_t> strides) {
2763  auto inferredType = llvm::cast<MemRefType>(
2764  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2765  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2766  "expected result rank to not exceed the inferred rank");
2767  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2768  return inferredType;
2769 
2770  // Compute which dimensions are dropped.
2771  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2772  computeRankReductionMask(inferredType.getShape(), resultShape);
2773  assert(dimsToProject.has_value() && "invalid rank reduction");
2774 
2775  // Compute the layout and result type.
2776  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2777  SmallVector<int64_t> rankReducedStrides;
2778  rankReducedStrides.reserve(resultShape.size());
2779  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2780  if (!dimsToProject->contains(idx))
2781  rankReducedStrides.push_back(value);
2782  }
2783  return MemRefType::get(resultShape, inferredType.getElementType(),
2784  StridedLayoutAttr::get(inferredLayout.getContext(),
2785  inferredLayout.getOffset(),
2786  rankReducedStrides),
2787  inferredType.getMemorySpace());
2788 }
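
A simplified, standalone illustration of the rank-reduction step above; the real computeRankReductionMask also handles dynamic sizes and ambiguous matches, so this sketch only covers the easy static case:

```
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

int main() {
  // Inferred type memref<1x4x1x8xf32, strided<[64, 16, 8, 1]>> reduced to
  // result shape [4, 8]: unit dims 0 and 2 are dropped along with their
  // strides, giving strided<[16, 1]>.
  std::vector<int64_t> inferredShape = {1, 4, 1, 8};
  std::vector<int64_t> inferredStrides = {64, 16, 8, 1};
  std::vector<int64_t> resultShape = {4, 8};

  // Greedily match result dims against inferred dims; unmatched dims
  // (necessarily unit dims here) are dropped.
  std::set<size_t> droppedDims;
  size_t r = 0;
  for (size_t i = 0; i < inferredShape.size(); ++i) {
    if (r < resultShape.size() && inferredShape[i] == resultShape[r])
      ++r;
    else
      droppedDims.insert(i);
  }

  std::vector<int64_t> rankReducedStrides;
  for (size_t i = 0; i < inferredStrides.size(); ++i)
    if (!droppedDims.count(i))
      rankReducedStrides.push_back(inferredStrides[i]);

  assert((droppedDims == std::set<size_t>{0, 2}));
  assert((rankReducedStrides == std::vector<int64_t>{16, 1}));
  return 0;
}
```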
2789 
2790 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2791  MemRefType sourceRankedTensorType,
2792  ArrayRef<OpFoldResult> offsets,
2793  ArrayRef<OpFoldResult> sizes,
2794  ArrayRef<OpFoldResult> strides) {
2795  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2796  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2797  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2798  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2799  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2800  return SubViewOp::inferRankReducedResultType(
2801  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2802  staticStrides);
2803 }
2804 
2805 // Build a SubViewOp with mixed static and dynamic entries and custom result
2806 // type. If the type passed is nullptr, it is inferred.
2807 void SubViewOp::build(OpBuilder &b, OperationState &result,
2808  MemRefType resultType, Value source,
2809  ArrayRef<OpFoldResult> offsets,
2810  ArrayRef<OpFoldResult> sizes,
2811  ArrayRef<OpFoldResult> strides,
2812  ArrayRef<NamedAttribute> attrs) {
2813  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2814  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2815  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2816  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2817  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2818  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2819  // Structuring implementation this way avoids duplication between builders.
2820  if (!resultType) {
2821  resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2822  sourceMemRefType, staticOffsets, staticSizes, staticStrides));
2823  }
2824  result.addAttributes(attrs);
2825  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2826  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2827  b.getDenseI64ArrayAttr(staticSizes),
2828  b.getDenseI64ArrayAttr(staticStrides));
2829 }
2830 
2831 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2832 // type.
2833 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2834  ArrayRef<OpFoldResult> offsets,
2835  ArrayRef<OpFoldResult> sizes,
2836  ArrayRef<OpFoldResult> strides,
2837  ArrayRef<NamedAttribute> attrs) {
2838  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2839 }
2840 
2841 // Build a SubViewOp with static entries and inferred result type.
2842 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2843  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2844  ArrayRef<int64_t> strides,
2845  ArrayRef<NamedAttribute> attrs) {
2846  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2847  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2848  return b.getI64IntegerAttr(v);
2849  }));
2850  SmallVector<OpFoldResult> sizeValues =
2851  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2852  return b.getI64IntegerAttr(v);
2853  }));
2854  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2855  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2856  return b.getI64IntegerAttr(v);
2857  }));
2858  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2859 }
2860 
2861 // Build a SubViewOp with static entries and custom result type. If the
2862 // type passed is nullptr, it is inferred.
2863 void SubViewOp::build(OpBuilder &b, OperationState &result,
2864  MemRefType resultType, Value source,
2865  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2866  ArrayRef<int64_t> strides,
2867  ArrayRef<NamedAttribute> attrs) {
2868  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2869  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2870  return b.getI64IntegerAttr(v);
2871  }));
2872  SmallVector<OpFoldResult> sizeValues =
2873  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2874  return b.getI64IntegerAttr(v);
2875  }));
2876  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2877  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2878  return b.getI64IntegerAttr(v);
2879  }));
2880  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2881  attrs);
2882 }
2883 
2884 // Build a SubViewOp with dynamic entries and custom result type. If the type
2885 // passed is nullptr, it is inferred.
2886 void SubViewOp::build(OpBuilder &b, OperationState &result,
2887  MemRefType resultType, Value source, ValueRange offsets,
2888  ValueRange sizes, ValueRange strides,
2889  ArrayRef<NamedAttribute> attrs) {
2890  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2891  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2892  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2893  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2894  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2895  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2896  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2897 }
2898 
2899 // Build a SubViewOp with dynamic entries and inferred result type.
2900 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2901  ValueRange offsets, ValueRange sizes, ValueRange strides,
2902  ArrayRef<NamedAttribute> attrs) {
2903  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2904 }
2905 
2906 /// For ViewLikeOpInterface.
2907 Value SubViewOp::getViewSource() { return getSource(); }
2908 
2909 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2910 /// static value).
2911 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2912  int64_t t1Offset, t2Offset;
2913  SmallVector<int64_t> t1Strides, t2Strides;
2914  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2915  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2916  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2917 }
2918 
2919 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2920 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2921 /// marked as dropped in `droppedDims`.
2922 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2923  const llvm::SmallBitVector &droppedDims) {
2924  assert(size_t(t1.getRank()) == droppedDims.size() &&
2925  "incorrect number of bits");
2926  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2927  "incorrect number of dropped dims");
2928  int64_t t1Offset, t2Offset;
2929  SmallVector<int64_t> t1Strides, t2Strides;
2930  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2931  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2932  if (failed(res1) || failed(res2))
2933  return false;
2934  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2935  if (droppedDims[i])
2936  continue;
2937  if (t1Strides[i] != t2Strides[j])
2938  return false;
2939  ++j;
2940  }
2941  return true;
2942 }
2943 
2944 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2945  Operation *op, Type expectedType) {
2946  auto memrefType = llvm::cast<ShapedType>(expectedType);
2947  switch (result) {
2948  case SliceVerificationResult::Success:
2949  return success();
2950  case SliceVerificationResult::RankTooLarge:
2951  return op->emitError("expected result rank to be smaller or equal to ")
2952  << "the source rank. ";
2953  case SliceVerificationResult::SizeMismatch:
2954  return op->emitError("expected result type to be ")
2955  << expectedType
2956  << " or a rank-reduced version. (mismatch of result sizes) ";
2957  case SliceVerificationResult::ElemTypeMismatch:
2958  return op->emitError("expected result element type to be ")
2959  << memrefType.getElementType();
2960  case SliceVerificationResult::MemSpaceMismatch:
2961  return op->emitError("expected result and source memory spaces to match.");
2962  case SliceVerificationResult::LayoutMismatch:
2963  return op->emitError("expected result type to be ")
2964  << expectedType
2965  << " or a rank-reduced version. (mismatch of result layout) ";
2966  }
2967  llvm_unreachable("unexpected subview verification result");
2968 }
2969 
2970 /// Verifier for SubViewOp.
2971 LogicalResult SubViewOp::verify() {
2972  MemRefType baseType = getSourceType();
2973  MemRefType subViewType = getType();
2974 
2975  // The base memref and the view memref should be in the same memory space.
2976  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2977  return emitError("different memory spaces specified for base memref "
2978  "type ")
2979  << baseType << " and subview memref type " << subViewType;
2980 
2981  // Verify that the base memref type has a strided layout map.
2982  if (!isStrided(baseType))
2983  return emitError("base type ") << baseType << " is not strided";
2984 
2985  // Compute the expected result type, assuming that there are no rank
2986  // reductions.
2987  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
2988  baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
2989 
2990  // Verify all properties of a shaped type: rank, element type and dimension
2991  // sizes. This takes into account potential rank reductions.
2992  auto shapedTypeVerification = isRankReducedType(
2993  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2994  if (shapedTypeVerification != SliceVerificationResult::Success)
2995  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2996 
2997  // Make sure that the memory space did not change.
2998  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
2999  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
3000  *this, expectedType);
3001 
3002  // Verify the offset of the layout map.
3003  if (!haveCompatibleOffsets(expectedType, subViewType))
3004  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3005  *this, expectedType);
3006 
3007  // The only things left to verify now are the strides. First, compute
3008  // the unused dimensions due to rank reductions. We have to look at sizes and
3009  // strides to decide which dimensions were dropped. This function also
3010  // partially verifies strides in case of rank reductions.
3011  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3012  getMixedSizes());
3013  if (failed(unusedDims))
3014  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3015  *this, expectedType);
3016 
3017  // Strides must match.
3018  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3019  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3020  *this, expectedType);
3021 
3022  return success();
3023 }
3024 
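// For illustration (hypothetical types): a rank-reducing subview accepted by
// this verifier. The expected non-reduced type is
// memref<1x4xf32, strided<[8, 1], offset: ?>>; dropping the unit dimension
// leaves compatible offsets and strides, so the rank-1 result type verifies.
// ```
// %0 = memref.subview %src[%i, 0] [1, 4] [1, 1]
//     : memref<8x8xf32> to memref<4xf32, strided<[1], offset: ?>>
// ```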
3025 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3026  return os << "range " << range.offset << ":" << range.size << ":"
3027  << range.stride;
3028 }
3029 
3030 /// Return the list of Range (i.e. offset, size, stride). Each Range
3031 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3032 /// with `b` at location `loc`.
3033 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3034  OpBuilder &b, Location loc) {
3035  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3036  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3037  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3038  SmallVector<Range, 8> res;
3039  unsigned rank = ranks[0];
3040  res.reserve(rank);
3041  for (unsigned idx = 0; idx < rank; ++idx) {
3042  Value offset =
3043  op.isDynamicOffset(idx)
3044  ? op.getDynamicOffset(idx)
3045  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
3046  Value size =
3047  op.isDynamicSize(idx)
3048  ? op.getDynamicSize(idx)
3049  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
3050  Value stride =
3051  op.isDynamicStride(idx)
3052  ? op.getDynamicStride(idx)
3053  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
3054  res.emplace_back(Range{offset, size, stride});
3055  }
3056  return res;
3057 }
3058 
3059 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3060 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3061 /// the rank of the inferred result type if `currentResultType` is lower rank
3062 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3063 /// together with the result type. In this case, it is important to compute
3064 /// the dropped dimensions using `currentSourceType` whose strides align with
3065 /// `currentResultType`.
3066 static MemRefType getCanonicalSubViewResultType(
3067  MemRefType currentResultType, MemRefType currentSourceType,
3068  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3069  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3070  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
3071  sourceType, mixedOffsets, mixedSizes, mixedStrides));
3072  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3073  currentSourceType, currentResultType, mixedSizes);
3074  if (failed(unusedDims))
3075  return nullptr;
3076 
3077  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3078  SmallVector<int64_t> shape, strides;
3079  unsigned numDimsAfterReduction =
3080  nonRankReducedType.getRank() - unusedDims->count();
3081  shape.reserve(numDimsAfterReduction);
3082  strides.reserve(numDimsAfterReduction);
3083  for (const auto &[idx, size, stride] :
3084  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3085  nonRankReducedType.getShape(), layout.getStrides())) {
3086  if (unusedDims->test(idx))
3087  continue;
3088  shape.push_back(size);
3089  strides.push_back(stride);
3090  }
3091 
3092  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3093  StridedLayoutAttr::get(sourceType.getContext(),
3094  layout.getOffset(), strides),
3095  nonRankReducedType.getMemorySpace());
3096 }
3097 
3098 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3099  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3100  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3101  unsigned rank = memrefType.getRank();
3102  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3103  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3104  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3105  auto targetType =
3106  llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
3107  targetShape, memrefType, offsets, sizes, strides));
3108  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3109  sizes, strides);
3110 }
3111 
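// A minimal sketch of the op this helper creates, assuming a source of type
// memref<4x1x16xf32> and targetShape = {4, 16} (illustrative values): zero
// offsets, unit strides, full sizes, and a rank-reduced result type.
// ```
// %0 = memref.subview %m[0, 0, 0] [4, 1, 16] [1, 1, 1]
//     : memref<4x1x16xf32> to memref<4x16xf32, strided<[16, 1]>>
// ```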
3112 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3113  Value value,
3114  ArrayRef<int64_t> desiredShape) {
3115  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3116  assert(sourceMemrefType && "not a ranked memref type");
3117  auto sourceShape = sourceMemrefType.getShape();
3118  if (sourceShape.equals(desiredShape))
3119  return value;
3120  auto maybeRankReductionMask =
3121  mlir::computeRankReductionMask(sourceShape, desiredShape);
3122  if (!maybeRankReductionMask)
3123  return failure();
3124  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3125 }
3126 
3127 /// Helper method to check if a `subview` operation is trivially a no-op. This
3128 /// is the case if all offsets are zero, all strides are 1, and the source
3129 /// shape is the same as the size of the subview. In such cases, the subview can
3130 /// be folded into its source.
3131 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3132  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3133  return false;
3134 
3135  auto mixedOffsets = subViewOp.getMixedOffsets();
3136  auto mixedSizes = subViewOp.getMixedSizes();
3137  auto mixedStrides = subViewOp.getMixedStrides();
3138 
3139  // Check offsets are zero.
3140  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3141  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3142  return !intValue || intValue.value() != 0;
3143  }))
3144  return false;
3145 
3146  // Check strides are one.
3147  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3148  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3149  return !intValue || intValue.value() != 1;
3150  }))
3151  return false;
3152 
3153  // Check that all size values are static and match the (static) source shape.
3154  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3155  for (const auto &size : llvm::enumerate(mixedSizes)) {
3156  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3157  if (!intValue || *intValue != sourceShape[size.index()])
3158  return false;
3159  }
3160  // All conditions met. The `SubViewOp` is foldable as a no-op.
3161  return true;
3162 }
3163 
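// For illustration (hypothetical IR): a subview this predicate classifies as
// trivial; TrivialSubViewOpFolder below replaces it with %m directly.
// ```
// %0 = memref.subview %m[0, 0] [8, 16] [1, 1]
//     : memref<8x16xf32> to memref<8x16xf32>
// ```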
3164 namespace {
3165 /// Pattern to rewrite a subview op with MemRefCast arguments.
3166 /// This essentially pushes memref.cast past its consuming subview when
3167 /// `canFoldIntoConsumerOp` is true.
3168 ///
3169 /// Example:
3170 /// ```
3171 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3172 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3173 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3174 /// ```
3175 /// is rewritten into:
3176 /// ```
3177 /// %0 = memref.subview %V[0, 0][3, 4][1, 1] : memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3178 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3179 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3180 /// ```
3181 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3182 public:
3183  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3184 
3185  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3186  PatternRewriter &rewriter) const override {
3187  // If any operand is constant, just return to let the constant-argument
3188  // folder (registered in getCanonicalizationPatterns) kick in.
3189  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3190  return matchPattern(operand, matchConstantIndex());
3191  }))
3192  return failure();
3193 
3194  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3195  if (!castOp)
3196  return failure();
3197 
3198  if (!CastOp::canFoldIntoConsumerOp(castOp))
3199  return failure();
3200 
3201  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3202  // the MemRefCastOp source operand type to infer the result type and the
3203  // current SubViewOp source operand type to compute the dropped dimensions
3204  // if the operation is rank-reducing.
3205  auto resultType = getCanonicalSubViewResultType(
3206  subViewOp.getType(), subViewOp.getSourceType(),
3207  llvm::cast<MemRefType>(castOp.getSource().getType()),
3208  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3209  subViewOp.getMixedStrides());
3210  if (!resultType)
3211  return failure();
3212 
3213  Value newSubView = rewriter.create<SubViewOp>(
3214  subViewOp.getLoc(), resultType, castOp.getSource(),
3215  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3216  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3217  subViewOp.getStaticStrides());
3218  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3219  newSubView);
3220  return success();
3221  }
3222 };
3223 
3224 /// Canonicalize no-op subview ops: replace them with their source, inserting
3225 /// a cast when the types differ only in layout (e.g. due to `affine_map`).
3226 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3227 public:
3228  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3229 
3230  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3231  PatternRewriter &rewriter) const override {
3232  if (!isTrivialSubViewOp(subViewOp))
3233  return failure();
3234  if (subViewOp.getSourceType() == subViewOp.getType()) {
3235  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3236  return success();
3237  }
3238  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3239  subViewOp.getSource());
3240  return success();
3241  }
3242 };
3243 } // namespace
3244 
3245 /// Return the canonical type of the result of a subview.
3246 struct SubViewReturnTypeCanonicalizer {
3247  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3248  ArrayRef<OpFoldResult> mixedSizes,
3249  ArrayRef<OpFoldResult> mixedStrides) {
3250  // Infer a memref type without taking into account any rank reductions.
3251  auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
3252  mixedSizes, mixedStrides);
3253  if (!resTy)
3254  return {};
3255  MemRefType nonReducedType = cast<MemRefType>(resTy);
3256 
3257  // Directly return the non-rank reduced type if there are no dropped dims.
3258  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3259  if (droppedDims.none())
3260  return nonReducedType;
3261 
3262  // Take the strides and offset from the non-rank reduced type.
3263  auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
3264 
3265  // Drop dims from shape and strides.
3266  SmallVector<int64_t> targetShape;
3267  SmallVector<int64_t> targetStrides;
3268  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3269  if (droppedDims.test(i))
3270  continue;
3271  targetStrides.push_back(nonReducedStrides[i]);
3272  targetShape.push_back(nonReducedType.getDimSize(i));
3273  }
3274 
3275  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3276  StridedLayoutAttr::get(nonReducedType.getContext(),
3277  offset, targetStrides),
3278  nonReducedType.getMemorySpace());
3279  }
3280 };
3281 
3282 /// A canonicalizer wrapper to replace SubViewOps.
3283 struct SubViewCanonicalizer {
3284  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3285  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3286  }
3287 };
3288 
3289 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3290  MLIRContext *context) {
3291  results
3292  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3293  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3294  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3295 }
3296 
3297 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3298  MemRefType sourceMemrefType = getSource().getType();
3299  MemRefType resultMemrefType = getResult().getType();
3300  auto resultLayout =
3301  dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3302 
3303  if (resultMemrefType == sourceMemrefType &&
3304  resultMemrefType.hasStaticShape() &&
3305  (!resultLayout || resultLayout.hasStaticLayout())) {
3306  return getViewSource();
3307  }
3308 
3309  // Fold subview(subview(x)), where both subviews have the same size and the
3310  // second subview's offsets are all zero. (I.e., the second subview is a
3311  // no-op.)
3312  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3313  auto srcSizes = srcSubview.getMixedSizes();
3314  auto sizes = getMixedSizes();
3315  auto offsets = getMixedOffsets();
3316  bool allOffsetsZero = llvm::all_of(
3317  offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3318  auto strides = getMixedStrides();
3319  bool allStridesOne = llvm::all_of(
3320  strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3321  bool allSizesSame = llvm::equal(sizes, srcSizes);
3322  if (allOffsetsZero && allStridesOne && allSizesSame &&
3323  resultMemrefType == sourceMemrefType)
3324  return getViewSource();
3325  }
3326 
3327  return {};
3328 }
3329 
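// A sketch of the subview-of-subview fold above, with illustrative types: the
// inner subview uses zero offsets, unit strides, the same sizes as the outer
// one, and the same result type, so it folds to %0.
// ```
// %0 = memref.subview %src[%i, 0] [4, 8] [1, 1]
//     : memref<8x8xf32> to memref<4x8xf32, strided<[8, 1], offset: ?>>
// %1 = memref.subview %0[0, 0] [4, 8] [1, 1]
//     : memref<4x8xf32, strided<[8, 1], offset: ?>>
//       to memref<4x8xf32, strided<[8, 1], offset: ?>>
// ```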
3330 //===----------------------------------------------------------------------===//
3331 // TransposeOp
3332 //===----------------------------------------------------------------------===//
3333 
3334 void TransposeOp::getAsmResultNames(
3335  function_ref<void(Value, StringRef)> setNameFn) {
3336  setNameFn(getResult(), "transpose");
3337 }
3338 
3339 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3340 static MemRefType inferTransposeResultType(MemRefType memRefType,
3341  AffineMap permutationMap) {
3342  auto originalSizes = memRefType.getShape();
3343  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
3344  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3345 
3346  // Compute permuted sizes and strides.
3347  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3348  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3349 
3350  return MemRefType::Builder(memRefType)
3351  .setShape(sizes)
3352  .setLayout(
3353  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3354 }
3355 
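// A small worked example (illustrative types): permuting a 4x8 source with
// (d0, d1) -> (d1, d0) swaps both the sizes and the strides.
// ```
// // inferTransposeResultType(memref<4x8xf32>, (d0, d1) -> (d1, d0))
// //   -> memref<8x4xf32, strided<[1, 8]>>
// ```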
3356 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3357  AffineMapAttr permutation,
3358  ArrayRef<NamedAttribute> attrs) {
3359  auto permutationMap = permutation.getValue();
3360  assert(permutationMap);
3361 
3362  auto memRefType = llvm::cast<MemRefType>(in.getType());
3363  // Compute result type.
3364  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3365 
3366  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3367  build(b, result, resultType, in, attrs);
3368 }
3369 
3370 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3371 void TransposeOp::print(OpAsmPrinter &p) {
3372  p << " " << getIn() << " " << getPermutation();
3373  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3374  p << " : " << getIn().getType() << " to " << getType();
3375 }
3376 
3377 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3378  OpAsmParser::UnresolvedOperand in;
3379  AffineMap permutation;
3380  MemRefType srcType, dstType;
3381  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3382  parser.parseOptionalAttrDict(result.attributes) ||
3383  parser.parseColonType(srcType) ||
3384  parser.resolveOperand(in, srcType, result.operands) ||
3385  parser.parseKeywordType("to", dstType) ||
3386  parser.addTypeToList(dstType, result.types))
3387  return failure();
3388 
3389  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3390  AffineMapAttr::get(permutation));
3391  return success();
3392 }
3393 
3394 LogicalResult TransposeOp::verify() {
3395  if (!getPermutation().isPermutation())
3396  return emitOpError("expected a permutation map");
3397  if (getPermutation().getNumDims() != getIn().getType().getRank())
3398  return emitOpError("expected a permutation map of same rank as the input");
3399 
3400  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3401  auto resultType = llvm::cast<MemRefType>(getType());
3402  auto canonicalResultType = canonicalizeStridedLayout(
3403  inferTransposeResultType(srcType, getPermutation()));
3404 
3405  if (canonicalizeStridedLayout(resultType) != canonicalResultType)
3406  return emitOpError("result type ")
3407  << resultType
3408  << " is not equivalent to the canonical transposed input type "
3409  << canonicalResultType;
3410  return success();
3411 }
3412 
3413 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3414  // First, check for an identity permutation; we can fold it away if the
3415  // input and result types are already identical.
3416  if (getPermutation().isIdentity() && getType() == getIn().getType())
3417  return getIn();
3418  // Fold two consecutive memref.transpose Ops into one by composing their
3419  // permutation maps.
3420  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3421  AffineMap composedPermutation =
3422  getPermutation().compose(otherTransposeOp.getPermutation());
3423  getInMutable().assign(otherTransposeOp.getIn());
3424  setPermutation(composedPermutation);
3425  return getResult();
3426  }
3427  return {};
3428 }
3429 
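// For illustration (hypothetical IR): the first case above in action. An
// identity permutation whose result type equals the input type folds to the
// input; back-to-back transposes are instead merged by composing their maps.
// ```
// %t = memref.transpose %m (d0, d1) -> (d0, d1)
//     : memref<4x8xf32> to memref<4x8xf32>
// // folds to %m
// ```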
3430 //===----------------------------------------------------------------------===//
3431 // ViewOp
3432 //===----------------------------------------------------------------------===//
3433 
3434 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3435  setNameFn(getResult(), "view");
3436 }
3437 
3438 LogicalResult ViewOp::verify() {
3439  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3440  auto viewType = getType();
3441 
3442  // The base memref should have identity layout map (or none).
3443  if (!baseType.getLayout().isIdentity())
3444  return emitError("unsupported map for base memref type ") << baseType;
3445 
3446  // The result memref should have identity layout map (or none).
3447  if (!viewType.getLayout().isIdentity())
3448  return emitError("unsupported map for result memref type ") << viewType;
3449 
3450  // The base memref and the view memref should be in the same memory space.
3451  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3452  return emitError("different memory spaces specified for base memref "
3453  "type ")
3454  << baseType << " and view memref type " << viewType;
3455 
3456  // Verify that we have the correct number of sizes for the result type.
3457  unsigned numDynamicDims = viewType.getNumDynamicDims();
3458  if (getSizes().size() != numDynamicDims)
3459  return emitError("incorrect number of size operands for type ") << viewType;
3460 
3461  return success();
3462 }
3463 
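// For illustration (hypothetical IR): a view that satisfies these checks. Both
// types use the identity layout and the default memory space, and exactly one
// size operand is supplied for the single dynamic result dimension.
// ```
// %v = memref.view %buf[%byte_shift][%size]
//     : memref<2048xi8> to memref<?x128xf32>
// ```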
3464 Value ViewOp::getViewSource() { return getSource(); }
3465 
3466 namespace {
3467 
3468 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3469  using OpRewritePattern<ViewOp>::OpRewritePattern;
3470 
3471  LogicalResult matchAndRewrite(ViewOp viewOp,
3472  PatternRewriter &rewriter) const override {
3473  // Return if none of the operands are constants.
3474  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3475  return matchPattern(operand, matchConstantIndex());
3476  }))
3477  return failure();
3478 
3479  // Get result memref type.
3480  auto memrefType = viewOp.getType();
3481 
3482  // Get offset from old memref view type 'memRefType'.
3483  int64_t oldOffset;
3484  SmallVector<int64_t, 4> oldStrides;
3485  if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
3486  return failure();
3487  assert(oldOffset == 0 && "Expected 0 offset");
3488 
3489  SmallVector<Value, 4> newOperands;
3490 
3491  // Offset cannot be folded into result type.
3492 
3493  // Fold any dynamic dim operands which are produced by a constant.
3494  SmallVector<int64_t, 4> newShapeConstants;
3495  newShapeConstants.reserve(memrefType.getRank());
3496 
3497  unsigned dynamicDimPos = 0;
3498  unsigned rank = memrefType.getRank();
3499  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3500  int64_t dimSize = memrefType.getDimSize(dim);
3501  // If this is already a static dimension, keep it.
3502  if (!ShapedType::isDynamic(dimSize)) {
3503  newShapeConstants.push_back(dimSize);
3504  continue;
3505  }
3506  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3507  if (auto constantIndexOp =
3508  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3509  // Dynamic shape dimension will be folded.
3510  newShapeConstants.push_back(constantIndexOp.value());
3511  } else {
3512  // Dynamic shape dimension not folded; copy operand from old memref.
3513  newShapeConstants.push_back(dimSize);
3514  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3515  }
3516  dynamicDimPos++;
3517  }
3518 
3519  // Create new memref type with constant folded dims.
3520  MemRefType newMemRefType =
3521  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3522  // Nothing new, don't fold.
3523  if (newMemRefType == memrefType)
3524  return failure();
3525 
3526  // Create new ViewOp.
3527  auto newViewOp = rewriter.create<ViewOp>(
3528  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3529  viewOp.getByteShift(), newOperands);
3530  // Insert a cast so we have the same type as the old memref type.
3531  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3532  return success();
3533  }
3534 };
3535 
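// A sketch of the rewrite performed by ViewOpShapeFolder, with illustrative
// IR: a constant size operand is folded into a static result type, and a cast
// restores the original dynamic type for existing users.
// ```
// %c64 = arith.constant 64 : index
// %v = memref.view %buf[%off][%c64] : memref<2048xi8> to memref<?x4xf32>
// ```
// is rewritten into:
// ```
// %0 = memref.view %buf[%off][] : memref<2048xi8> to memref<64x4xf32>
// %v = memref.cast %0 : memref<64x4xf32> to memref<?x4xf32>
// ```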
3536 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3537  using OpRewritePattern<ViewOp>::OpRewritePattern;
3538 
3539  LogicalResult matchAndRewrite(ViewOp viewOp,
3540  PatternRewriter &rewriter) const override {
3541  Value memrefOperand = viewOp.getOperand(0);
3542  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3543  if (!memrefCastOp)
3544  return failure();
3545  Value allocOperand = memrefCastOp.getOperand();
3546  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3547  if (!allocOp)
3548  return failure();
3549  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3550  viewOp.getByteShift(),
3551  viewOp.getSizes());
3552  return success();
3553  }
3554 };
3555 
3556 } // namespace
3557 
3558 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3559  MLIRContext *context) {
3560  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3561 }
3562 
3563 //===----------------------------------------------------------------------===//
3564 // AtomicRMWOp
3565 //===----------------------------------------------------------------------===//
3566 
3567 LogicalResult AtomicRMWOp::verify() {
3568  if (getMemRefType().getRank() != getNumOperands() - 2)
3569  return emitOpError(
3570  "expects the number of subscripts to be equal to memref rank");
3571  switch (getKind()) {
3572  case arith::AtomicRMWKind::addf:
3573  case arith::AtomicRMWKind::maximumf:
3574  case arith::AtomicRMWKind::minimumf:
3575  case arith::AtomicRMWKind::mulf:
3576  if (!llvm::isa<FloatType>(getValue().getType()))
3577  return emitOpError() << "with kind '"
3578  << arith::stringifyAtomicRMWKind(getKind())
3579  << "' expects a floating-point type";
3580  break;
3581  case arith::AtomicRMWKind::addi:
3582  case arith::AtomicRMWKind::maxs:
3583  case arith::AtomicRMWKind::maxu:
3584  case arith::AtomicRMWKind::mins:
3585  case arith::AtomicRMWKind::minu:
3586  case arith::AtomicRMWKind::muli:
3587  case arith::AtomicRMWKind::ori:
3588  case arith::AtomicRMWKind::andi:
3589  if (!llvm::isa<IntegerType>(getValue().getType()))
3590  return emitOpError() << "with kind '"
3591  << arith::stringifyAtomicRMWKind(getKind())
3592  << "' expects an integer type";
3593  break;
3594  default:
3595  break;
3596  }
3597  return success();
3598 }
3599 
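// For illustration (hypothetical IR): an atomic_rmw that satisfies both checks
// above, with one subscript for the rank-1 memref and a floating-point value
// for the `addf` kind.
// ```
// %old = memref.atomic_rmw addf %val, %buf[%i] : (f32, memref<128xf32>) -> f32
// ```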
3600 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3601  /// atomicrmw(memrefcast) -> atomicrmw
3602  if (succeeded(foldMemRefCast(*this, getValue())))
3603  return getResult();
3604  return OpFoldResult();
3605 }
3606 
3607 //===----------------------------------------------------------------------===//
3608 // TableGen'd op method definitions
3609 //===----------------------------------------------------------------------===//
3610 
3611 #define GET_OP_CLASSES
3612 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"