MemRefOps.cpp (MLIR 20.0.0git)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Arith/Utils/Utils.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/Utils/StaticValueUtils.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/ViewLikeInterface.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 /// Materialize a single constant operation from a given attribute value with
30 /// the desired resultant type.
31 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
32  Attribute value, Type type,
33  Location loc) {
34  return arith::ConstantOp::materialize(builder, value, type, loc);
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // Common canonicalization pattern support logic
39 //===----------------------------------------------------------------------===//
40 
41 /// This is a common class used for patterns of the form
42 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
43 /// into the root operation directly.
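/// For example (an illustrative sketch, not taken from this file), ops whose
/// `fold` hook calls this helper, such as memref.dealloc, can look through a
/// defining memref.cast:
/// ```
/// %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
/// memref.dealloc %0 : memref<?xf32>
/// ```
/// folds to `memref.dealloc %arg : memref<8xf32>`.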
44 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
45  bool folded = false;
46  for (OpOperand &operand : op->getOpOperands()) {
47  auto cast = operand.get().getDefiningOp<CastOp>();
48  if (cast && operand.get() != inner &&
49  !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
50  operand.set(cast.getOperand());
51  folded = true;
52  }
53  }
54  return success(folded);
55 }
56 
57 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
58 /// type.
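/// For example (illustrative): memref<4x?xf32> maps to tensor<4x?xf32>,
/// memref<*xf32> maps to tensor<*xf32>, and any non-memref type yields
/// NoneType.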
59 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
60  if (auto memref = llvm::dyn_cast<MemRefType>(type))
61  return RankedTensorType::get(memref.getShape(), memref.getElementType());
62  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
63  return UnrankedTensorType::get(memref.getElementType());
64  return NoneType::get(type.getContext());
65 }
66 
67 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
68  int64_t dim) {
69  auto memrefType = llvm::cast<MemRefType>(value.getType());
71  if (memrefType.isDynamicDim(dim))
72  return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74  return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78  Location loc, Value value) {
79  auto memrefType = llvm::cast<MemRefType>(value.getType());
80  SmallVector<OpFoldResult> result;
81  for (int64_t i = 0; i < memrefType.getRank(); ++i)
82  result.push_back(getMixedSize(builder, loc, value, i));
83  return result;
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that infers the constant values from a list of \p values,
91 /// a \p memRefTy, and another helper function \p getAttributes.
92 /// The inferred constant values replace the related `OpFoldResult` in
93 /// \p values.
94 ///
95 /// \note This function shouldn't be used directly, instead, use the
96 /// `getConstifiedMixedXXX` methods from the related operations.
97 ///
98 /// \p getAttributes returns a list of potentially constant values, as determined
99 /// by \p isDynamic, from the given \p memRefTy. The returned list must have as
100 /// many elements as \p values or be empty.
101 ///
102 /// E.g., consider the following example:
103 /// ```
104 /// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
105 /// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
106 /// ```
107 /// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
108 /// Now using this helper function with:
109 /// - `values == [2, %dyn_stride]`,
110 /// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
111 /// - `getAttributes == getConstantStrides` (i.e., a wrapper around
112 /// `getStridesAndOffset`), and
113 /// - `isDynamic == ShapedType::isDynamic`
114 /// Will yield: `values == [2, 1]`
115 static void constifyIndexValues(
116  SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
117  MLIRContext *ctxt,
118  llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
119  llvm::function_ref<bool(int64_t)> isDynamic) {
120  SmallVector<int64_t> constValues = getAttributes(memRefTy);
121  Builder builder(ctxt);
122  for (const auto &it : llvm::enumerate(constValues)) {
123  int64_t constValue = it.value();
124  if (!isDynamic(constValue))
125  values[it.index()] = builder.getIndexAttr(constValue);
126  }
127  for (OpFoldResult &ofr : values) {
128  if (auto attr = dyn_cast<Attribute>(ofr)) {
129  // FIXME: We shouldn't need to do that, but right now, the static indices
130  // are created with the wrong type: `i64` instead of `index`.
131  // As a result, if we were to keep the attribute as is, we may fail to see
132  // that two attributes are equal because one would have the i64 type and
133  // the other the index type.
134  // The alternative would be to create constant indices with getI64Attr in
135  // this and the previous loop, but it doesn't logically make sense (we are
136  // dealing with indices here) and would only strengthen the inconsistency
137  // around how static indices are created (some places use getI64Attr,
138  // others use getIndexAttr).
139  // The workaround here is to stick to the IndexAttr type for all the
140  // values, hence we recreate the attribute even when it is already static
141  // to make sure the type is consistent.
142  ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
143  continue;
144  }
145  std::optional<int64_t> maybeConstant =
146  getConstantIntValue(cast<Value>(ofr));
147  if (maybeConstant)
148  ofr = builder.getIndexAttr(*maybeConstant);
149  }
150 }
151 
152 /// Wrapper around `getShape` that conforms to the function signature
153 /// expected for `getAttributes` in `constifyIndexValues`.
154 static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
155  ArrayRef<int64_t> sizes = memRefTy.getShape();
156  return SmallVector<int64_t>(sizes);
157 }
158 
159 /// Wrapper around `getStridesAndOffset` that returns only the offset and
160 /// conforms to the function signature expected for `getAttributes` in
161 /// `constifyIndexValues`.
162 static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
163  SmallVector<int64_t> strides;
164  int64_t offset;
165  LogicalResult hasStaticInformation =
166  getStridesAndOffset(memrefType, strides, offset);
167  if (failed(hasStaticInformation))
168  return SmallVector<int64_t>();
169  return SmallVector<int64_t>(1, offset);
170 }
171 
172 /// Wrapper around `getStridesAndOffset` that returns only the strides and
173 /// conforms to the function signature expected for `getAttributes` in
174 /// `constifyIndexValues`.
175 static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
176  SmallVector<int64_t> strides;
177  int64_t offset;
178  LogicalResult hasStaticInformation =
179  getStridesAndOffset(memrefType, strides, offset);
180  if (failed(hasStaticInformation))
181  return SmallVector<int64_t>();
182  return strides;
183 }
184 
185 //===----------------------------------------------------------------------===//
186 // AllocOp / AllocaOp
187 //===----------------------------------------------------------------------===//
188 
189 void AllocOp::getAsmResultNames(
190  function_ref<void(Value, StringRef)> setNameFn) {
191  setNameFn(getResult(), "alloc");
192 }
193 
194 void AllocaOp::getAsmResultNames(
195  function_ref<void(Value, StringRef)> setNameFn) {
196  setNameFn(getResult(), "alloca");
197 }
198 
199 template <typename AllocLikeOp>
200 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
201  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
202  "applies to only alloc or alloca");
203  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
204  if (!memRefType)
205  return op.emitOpError("result must be a memref");
206 
207  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
208  return op.emitOpError("dimension operand count does not equal memref "
209  "dynamic dimension count");
210 
211  unsigned numSymbols = 0;
212  if (!memRefType.getLayout().isIdentity())
213  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
214  if (op.getSymbolOperands().size() != numSymbols)
215  return op.emitOpError("symbol operand count does not equal memref symbol "
216  "count: expected ")
217  << numSymbols << ", got " << op.getSymbolOperands().size();
218 
219  return success();
220 }
221 
222 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
223 
224 LogicalResult AllocaOp::verify() {
225  // An alloca op needs to have an ancestor with an allocation scope trait.
226  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
227  return emitOpError(
228  "requires an ancestor op with AutomaticAllocationScope trait");
229 
230  return verifyAllocLikeOp(*this);
231 }
232 
233 namespace {
234 /// Fold constant dimensions into an alloc like operation.
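/// For example (an illustrative sketch; %c8 stands for an
/// `arith.constant 8 : index` that is assumed to exist):
/// ```
/// %a = memref.alloc(%c8) : memref<?xf32>
/// ```
/// is rewritten to a static allocation plus a cast back to the original type:
/// ```
/// %0 = memref.alloc() : memref<8xf32>
/// %a = memref.cast %0 : memref<8xf32> to memref<?xf32>
/// ```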
235 template <typename AllocLikeOp>
236 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
237  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
238 
239  LogicalResult matchAndRewrite(AllocLikeOp alloc,
240  PatternRewriter &rewriter) const override {
241  // Check to see if any dimension operands are constants. If so, we can
242  // substitute and drop them.
243  if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
244  APInt constSizeArg;
245  if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
246  return false;
247  return constSizeArg.isNonNegative();
248  }))
249  return failure();
250 
251  auto memrefType = alloc.getType();
252 
253  // Ok, we have one or more constant operands. Collect the non-constant ones
254  // and keep track of the resultant memref type to build.
255  SmallVector<int64_t, 4> newShapeConstants;
256  newShapeConstants.reserve(memrefType.getRank());
257  SmallVector<Value, 4> dynamicSizes;
258 
259  unsigned dynamicDimPos = 0;
260  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
261  int64_t dimSize = memrefType.getDimSize(dim);
262  // If this is already a static dimension, keep it.
263  if (!ShapedType::isDynamic(dimSize)) {
264  newShapeConstants.push_back(dimSize);
265  continue;
266  }
267  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
268  APInt constSizeArg;
269  if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
270  constSizeArg.isNonNegative()) {
271  // Dynamic shape dimension will be folded.
272  newShapeConstants.push_back(constSizeArg.getZExtValue());
273  } else {
274  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
275  newShapeConstants.push_back(ShapedType::kDynamic);
276  dynamicSizes.push_back(dynamicSize);
277  }
278  dynamicDimPos++;
279  }
280 
281  // Create new memref type (which will have fewer dynamic dimensions).
282  MemRefType newMemRefType =
283  MemRefType::Builder(memrefType).setShape(newShapeConstants);
284  assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
285 
286  // Create and insert the alloc op for the new memref.
287  auto newAlloc = rewriter.create<AllocLikeOp>(
288  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
289  alloc.getAlignmentAttr());
290  // Insert a cast so we have the same type as the old alloc.
291  rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
292  return success();
293  }
294 };
295 
296 /// Fold alloc operations with no users or only store and dealloc uses.
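/// For example (an illustrative sketch), an allocation that is only written to
/// and deallocated can be removed together with its store and dealloc users:
/// ```
/// %a = memref.alloc() : memref<4xf32>
/// memref.store %v, %a[%i] : memref<4xf32>
/// memref.dealloc %a : memref<4xf32>
/// ```
/// All three operations are erased.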
297 template <typename T>
298 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
299  using OpRewritePattern<T>::OpRewritePattern;
300 
301  LogicalResult matchAndRewrite(T alloc,
302  PatternRewriter &rewriter) const override {
303  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
304  if (auto storeOp = dyn_cast<StoreOp>(op))
305  return storeOp.getValue() == alloc;
306  return !isa<DeallocOp>(op);
307  }))
308  return failure();
309 
310  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
311  rewriter.eraseOp(user);
312 
313  rewriter.eraseOp(alloc);
314  return success();
315  }
316 };
317 } // namespace
318 
319 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
320  MLIRContext *context) {
321  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
322 }
323 
324 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
325  MLIRContext *context) {
326  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
327  context);
328 }
329 
330 //===----------------------------------------------------------------------===//
331 // ReallocOp
332 //===----------------------------------------------------------------------===//
333 
334 LogicalResult ReallocOp::verify() {
335  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
336  MemRefType resultType = getType();
337 
338  // The source memref should have identity layout (or none).
339  if (!sourceType.getLayout().isIdentity())
340  return emitError("unsupported layout for source memref type ")
341  << sourceType;
342 
343  // The result memref should have identity layout (or none).
344  if (!resultType.getLayout().isIdentity())
345  return emitError("unsupported layout for result memref type ")
346  << resultType;
347 
348  // The source memref and the result memref should be in the same memory space.
349  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
350  return emitError("different memory spaces specified for source memref "
351  "type ")
352  << sourceType << " and result memref type " << resultType;
353 
354  // The source memref and the result memref should have the same element type.
355  if (sourceType.getElementType() != resultType.getElementType())
356  return emitError("different element types specified for source memref "
357  "type ")
358  << sourceType << " and result memref type " << resultType;
359 
360  // Verify that we have the dynamic dimension operand when it is needed.
361  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
362  return emitError("missing dimension operand for result type ")
363  << resultType;
364  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
365  return emitError("unnecessary dimension operand for result type ")
366  << resultType;
367 
368  return success();
369 }
370 
371 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
372  MLIRContext *context) {
373  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // AllocaScopeOp
378 //===----------------------------------------------------------------------===//
379 
380 void AllocaScopeOp::print(OpAsmPrinter &p) {
381  bool printBlockTerminators = false;
382 
383  p << ' ';
384  if (!getResults().empty()) {
385  p << " -> (" << getResultTypes() << ")";
386  printBlockTerminators = true;
387  }
388  p << ' ';
389  p.printRegion(getBodyRegion(),
390  /*printEntryBlockArgs=*/false,
391  /*printBlockTerminators=*/printBlockTerminators);
392  p.printOptionalAttrDict((*this)->getAttrs());
393 }
394 
395 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
396  // Create a region for the body.
397  result.regions.reserve(1);
398  Region *bodyRegion = result.addRegion();
399 
400  // Parse optional results type list.
401  if (parser.parseOptionalArrowTypeList(result.types))
402  return failure();
403 
404  // Parse the body region.
405  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
406  return failure();
407  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
408  result.location);
409 
410  // Parse the optional attribute list.
411  if (parser.parseOptionalAttrDict(result.attributes))
412  return failure();
413 
414  return success();
415 }
416 
417 void AllocaScopeOp::getSuccessorRegions(
418  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
419  if (!point.isParent()) {
420  regions.push_back(RegionSuccessor(getResults()));
421  return;
422  }
423 
424  regions.push_back(RegionSuccessor(&getBodyRegion()));
425 }
426 
427 /// Given an operation, return whether this op is guaranteed to
428 /// allocate an AutomaticAllocationScopeResource
429 static bool isGuaranteedAutomaticAllocation(Operation *op) {
430  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
431  if (!interface)
432  return false;
433  for (auto res : op->getResults()) {
434  if (auto effect =
435  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
436  if (isa<SideEffects::AutomaticAllocationScopeResource>(
437  effect->getResource()))
438  return true;
439  }
440  }
441  return false;
442 }
443 
444 /// Given an operation, return whether this op itself could
445 /// allocate an AutomaticAllocationScopeResource. Note that
446 /// this will not check whether an operation contained within
447 /// the op can allocate.
448 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
449  // This op itself doesn't create a stack allocation,
450  // the inner allocation should be handled separately.
451  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
452  return false;
453  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
454  if (!interface)
455  return true;
456  for (auto res : op->getResults()) {
457  if (auto effect =
458  interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
459  if (isa<SideEffects::AutomaticAllocationScopeResource>(
460  effect->getResource()))
461  return true;
462  }
463  }
464  return false;
465 }
466 
467 /// Return whether this op is the last non terminating op
468 /// in a region. That is to say, it is in a one-block region
469 /// and is only followed by a terminator. This prevents
470 /// extending the lifetime of allocations.
471 static bool lastNonTerminatorInRegion(Operation *op) {
472  return op->getNextNode() == op->getBlock()->getTerminator() &&
473  op->getParentRegion()->getBlocks().size() == 1;
474 }
475 
476 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
477 /// or it contains no allocation.
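/// For example (an illustrative sketch; "test.use" is a placeholder op), a
/// scope that is the last non-terminator op in the body of a func.func, which
/// is itself an AutomaticAllocationScope, can be inlined into the function:
/// ```
/// func.func @f() {
///   memref.alloca_scope {
///     %0 = memref.alloca() : memref<4xf32>
///     "test.use"(%0) : (memref<4xf32>) -> ()
///   }
///   return
/// }
/// ```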
478 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
479  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
480 
481  LogicalResult matchAndRewrite(AllocaScopeOp op,
482  PatternRewriter &rewriter) const override {
483  bool hasPotentialAlloca =
484  op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
485  if (alloc == op)
486  return WalkResult::advance();
487  if (isOpItselfPotentialAutomaticAllocation(alloc))
488  return WalkResult::interrupt();
489  if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
490  return WalkResult::skip();
491  return WalkResult::advance();
492  }).wasInterrupted();
493 
494  // If this contains no potential allocation, it is always legal to
495  // inline. Otherwise, consider two conditions:
496  if (hasPotentialAlloca) {
497  // If the parent isn't an allocation scope, or we are not the last
498  // non-terminator op in the parent, we will extend the lifetime.
499  if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
500  return failure();
501  if (!lastNonTerminatorInRegion(op))
502  return failure();
503  }
504 
505  Block *block = &op.getRegion().front();
506  Operation *terminator = block->getTerminator();
507  ValueRange results = terminator->getOperands();
508  rewriter.inlineBlockBefore(block, op);
509  rewriter.replaceOp(op, results);
510  rewriter.eraseOp(terminator);
511  return success();
512  }
513 };
514 
515 /// Move allocations into an allocation scope, if it is legal to
516 /// move them (e.g. their operands are available at the location
517 /// the op would be moved to).
518 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
519  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
520 
521  LogicalResult matchAndRewrite(AllocaScopeOp op,
522  PatternRewriter &rewriter) const override {
523 
524  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
525  return failure();
526 
527  Operation *lastParentWithoutScope = op->getParentOp();
528 
529  if (!lastParentWithoutScope ||
530  lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
531  return failure();
532 
533  // Only apply if this is the last non-terminator
534  // op in the block (lest lifetime be extended) of a one
535  // block region.
536  if (!lastNonTerminatorInRegion(op) ||
537  !lastNonTerminatorInRegion(lastParentWithoutScope))
538  return failure();
539 
540  while (!lastParentWithoutScope->getParentOp()
541  ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
542  lastParentWithoutScope = lastParentWithoutScope->getParentOp();
543  if (!lastParentWithoutScope ||
544  !lastNonTerminatorInRegion(lastParentWithoutScope))
545  return failure();
546  }
547  assert(lastParentWithoutScope->getParentOp()
548  ->hasTrait<OpTrait::AutomaticAllocationScope>());
549 
550  Region *containingRegion = nullptr;
551  for (auto &r : lastParentWithoutScope->getRegions()) {
552  if (r.isAncestor(op->getParentRegion())) {
553  assert(containingRegion == nullptr &&
554  "only one region can contain the op");
555  containingRegion = &r;
556  }
557  }
558  assert(containingRegion && "op must be contained in a region");
559 
560  SmallVector<Operation *> toHoist;
561  op->walk([&](Operation *alloc) {
562  if (!isGuaranteedAutomaticAllocation(alloc))
563  return WalkResult::skip();
564 
565  // If any operand is not defined before the location of
566  // lastParentWithoutScope (i.e. where we would hoist to), skip.
567  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
568  return containingRegion->isAncestor(v.getParentRegion());
569  }))
570  return WalkResult::skip();
571  toHoist.push_back(alloc);
572  return WalkResult::advance();
573  });
574 
575  if (toHoist.empty())
576  return failure();
577  rewriter.setInsertionPoint(lastParentWithoutScope);
578  for (auto *op : toHoist) {
579  auto *cloned = rewriter.clone(*op);
580  rewriter.replaceOp(op, cloned->getResults());
581  }
582  return success();
583  }
584 };
585 
586 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
587  MLIRContext *context) {
588  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
589 }
590 
591 //===----------------------------------------------------------------------===//
592 // AssumeAlignmentOp
593 //===----------------------------------------------------------------------===//
594 
595 LogicalResult AssumeAlignmentOp::verify() {
596  if (!llvm::isPowerOf2_32(getAlignment()))
597  return emitOpError("alignment must be power of 2");
598  return success();
599 }
600 
601 //===----------------------------------------------------------------------===//
602 // CastOp
603 //===----------------------------------------------------------------------===//
604 
605 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
606  setNameFn(getResult(), "cast");
607 }
608 
609 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
610 /// source memref. This is useful to fold a memref.cast into a consuming op
611 /// and implement canonicalization patterns for ops in different dialects that
612 /// may consume the results of memref.cast operations. Such foldable memref.cast
613 /// operations are typically inserted as `view` and `subview` ops are
614 /// canonicalized, to preserve the type compatibility of their uses.
615 ///
616 /// Returns true when all conditions are met:
617 /// 1. source and result are ranked memrefs with strided semantics and same
618 /// element type and rank.
619 /// 2. each of the source's size, offset or stride has more static information
620 /// than the corresponding result's size, offset or stride.
621 ///
622 /// Example 1:
623 /// ```mlir
624 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
625 /// %2 = consumer %1 ... : memref<?x?xf32> ...
626 /// ```
627 ///
628 /// may fold into:
629 ///
630 /// ```mlir
631 /// %2 = consumer %0 ... : memref<8x16xf32> ...
632 /// ```
633 ///
634 /// Example 2:
635 /// ```
636 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
637 /// to memref<?x?xf32>
638 /// consumer %1 : memref<?x?xf32> ...
639 /// ```
640 ///
641 /// may fold into:
642 ///
643 /// ```
644 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
645 /// ```
646 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
647  MemRefType sourceType =
648  llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
649  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
650 
651  // Requires ranked MemRefType.
652  if (!sourceType || !resultType)
653  return false;
654 
655  // Requires same elemental type.
656  if (sourceType.getElementType() != resultType.getElementType())
657  return false;
658 
659  // Requires same rank.
660  if (sourceType.getRank() != resultType.getRank())
661  return false;
662 
663  // Only fold casts between strided memref forms.
664  int64_t sourceOffset, resultOffset;
665  SmallVector<int64_t, 4> sourceStrides, resultStrides;
666  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
667  failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
668  return false;
669 
670  // If cast is towards more static sizes along any dimension, don't fold.
671  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
672  auto ss = std::get<0>(it), st = std::get<1>(it);
673  if (ss != st)
674  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
675  return false;
676  }
677 
678  // If cast is towards more static offset along any dimension, don't fold.
679  if (sourceOffset != resultOffset)
680  if (ShapedType::isDynamic(sourceOffset) &&
681  !ShapedType::isDynamic(resultOffset))
682  return false;
683 
684  // If cast is towards more static strides along any dimension, don't fold.
685  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
686  auto ss = std::get<0>(it), st = std::get<1>(it);
687  if (ss != st)
688  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
689  return false;
690  }
691 
692  return true;
693 }
694 
695 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
696  if (inputs.size() != 1 || outputs.size() != 1)
697  return false;
698  Type a = inputs.front(), b = outputs.front();
699  auto aT = llvm::dyn_cast<MemRefType>(a);
700  auto bT = llvm::dyn_cast<MemRefType>(b);
701 
702  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
703  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
704 
705  if (aT && bT) {
706  if (aT.getElementType() != bT.getElementType())
707  return false;
708  if (aT.getLayout() != bT.getLayout()) {
709  int64_t aOffset, bOffset;
710  SmallVector<int64_t, 4> aStrides, bStrides;
711  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
712  failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
713  aStrides.size() != bStrides.size())
714  return false;
715 
716  // Strides along a dimension/offset are compatible if the value in the
717  // source memref is static and the value in the target memref is the
718  // same. They are also compatible if either one is dynamic (see
719  // description of MemRefCastOp for details).
720  auto checkCompatible = [](int64_t a, int64_t b) {
721  return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
722  };
723  if (!checkCompatible(aOffset, bOffset))
724  return false;
725  for (const auto &aStride : enumerate(aStrides))
726  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
727  return false;
728  }
729  if (aT.getMemorySpace() != bT.getMemorySpace())
730  return false;
731 
732  // They must have the same rank, and any specified dimensions must match.
733  if (aT.getRank() != bT.getRank())
734  return false;
735 
736  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
737  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
738  if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
739  aDim != bDim)
740  return false;
741  }
742  return true;
743  } else {
744  if (!aT && !uaT)
745  return false;
746  if (!bT && !ubT)
747  return false;
748  // Unranked to unranked casting is unsupported
749  if (uaT && ubT)
750  return false;
751 
752  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
753  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
754  if (aEltType != bEltType)
755  return false;
756 
757  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
758  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
759  return aMemSpace == bMemSpace;
760  }
761 
762  return false;
763 }
764 
765 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
766  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
767 }
768 
769 //===----------------------------------------------------------------------===//
770 // CopyOp
771 //===----------------------------------------------------------------------===//
772 
773 namespace {
774 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
775 /// and element type, the cast can be skipped. Such CastOps only cast the layout
776 /// of the type.
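/// For example (an illustrative sketch; only the layout differs across the
/// cast, so the copy can read from the original value directly):
/// ```
/// %0 = memref.cast %src : memref<4xf32> to memref<4xf32, strided<[1], offset: ?>>
/// memref.copy %0, %dst : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32>
/// ```
/// becomes `memref.copy %src, %dst : memref<4xf32> to memref<4xf32>`.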
777 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
778  using OpRewritePattern<CopyOp>::OpRewritePattern;
779 
780  LogicalResult matchAndRewrite(CopyOp copyOp,
781  PatternRewriter &rewriter) const override {
782  bool modified = false;
783 
784  // Check source.
785  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
786  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
787  auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
788 
789  if (fromType && toType) {
790  if (fromType.getShape() == toType.getShape() &&
791  fromType.getElementType() == toType.getElementType()) {
792  rewriter.modifyOpInPlace(copyOp, [&] {
793  copyOp.getSourceMutable().assign(castOp.getSource());
794  });
795  modified = true;
796  }
797  }
798  }
799 
800  // Check target.
801  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
802  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
803  auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
804 
805  if (fromType && toType) {
806  if (fromType.getShape() == toType.getShape() &&
807  fromType.getElementType() == toType.getElementType()) {
808  rewriter.modifyOpInPlace(copyOp, [&] {
809  copyOp.getTargetMutable().assign(castOp.getSource());
810  });
811  modified = true;
812  }
813  }
814  }
815 
816  return success(modified);
817  }
818 };
819 
820 /// Fold memref.copy(%x, %x).
821 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
822  using OpRewritePattern<CopyOp>::OpRewritePattern;
823 
824  LogicalResult matchAndRewrite(CopyOp copyOp,
825  PatternRewriter &rewriter) const override {
826  if (copyOp.getSource() != copyOp.getTarget())
827  return failure();
828 
829  rewriter.eraseOp(copyOp);
830  return success();
831  }
832 };
833 
834 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
835  using OpRewritePattern<CopyOp>::OpRewritePattern;
836 
837  static bool isEmptyMemRef(BaseMemRefType type) {
838  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
839  }
840 
841  LogicalResult matchAndRewrite(CopyOp copyOp,
842  PatternRewriter &rewriter) const override {
843  if (isEmptyMemRef(copyOp.getSource().getType()) ||
844  isEmptyMemRef(copyOp.getTarget().getType())) {
845  rewriter.eraseOp(copyOp);
846  return success();
847  }
848 
849  return failure();
850  }
851 };
852 } // namespace
853 
854 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
855  MLIRContext *context) {
856  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
857 }
858 
859 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
860  SmallVectorImpl<OpFoldResult> &results) {
861  /// copy(memrefcast) -> copy
862  bool folded = false;
863  Operation *op = *this;
864  for (OpOperand &operand : op->getOpOperands()) {
865  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
866  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
867  operand.set(castOp.getOperand());
868  folded = true;
869  }
870  }
871  return success(folded);
872 }
873 
874 //===----------------------------------------------------------------------===//
875 // DeallocOp
876 //===----------------------------------------------------------------------===//
877 
878 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
879  SmallVectorImpl<OpFoldResult> &results) {
880  /// dealloc(memrefcast) -> dealloc
881  return foldMemRefCast(*this);
882 }
883 
884 //===----------------------------------------------------------------------===//
885 // DimOp
886 //===----------------------------------------------------------------------===//
887 
888 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
889  setNameFn(getResult(), "dim");
890 }
891 
892 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
893  int64_t index) {
894  auto loc = result.location;
895  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
896  build(builder, result, source, indexValue);
897 }
898 
899 std::optional<int64_t> DimOp::getConstantIndex() {
900  return getConstantIntValue(getIndex());
901 }
902 
903 Speculation::Speculatability DimOp::getSpeculatability() {
904  auto constantIndex = getConstantIndex();
905  if (!constantIndex)
906  return Speculation::NotSpeculatable;
907 
908  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
909  if (!rankedSourceType)
910  return Speculation::NotSpeculatable;
911 
912  if (rankedSourceType.getRank() <= constantIndex)
913  return Speculation::NotSpeculatable;
914 
915  return Speculation::Speculatable;
916 }
917 
918 /// Return a map with key being elements in `vals` and data being number of
919 /// occurrences of it. Use std::map, since the `vals` here are strides and the
920 /// dynamic stride value is the same as the tombstone value for
921 /// `DenseMap<int64_t>`.
922 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
923  std::map<int64_t, unsigned> numOccurences;
924  for (auto val : vals)
925  numOccurences[val]++;
926  return numOccurences;
927 }
928 
929 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
930 /// to be a subset of `originalType` with some `1` entries erased, return the
931 /// set of indices that specifies which of the entries of `originalShape` are
932 /// dropped to obtain `reducedShape`.
933 /// This accounts for cases where there are multiple unit-dims, but only a
934 /// subset of those are dropped. For MemRefTypes these can be disambiguated
935 /// using the strides. If a dimension is dropped the stride must be dropped too.
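/// For example (an illustrative sketch): reducing `memref<1x4x1xf32>` with
/// strides `[4, 1, 1]` to `memref<4x1xf32>` with strides `[1, 1]` must drop
/// dim 0, since stride 4 no longer occurs in the reduced type while stride 1
/// still occurs twice, so the unit dim 2 is kept.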
936 static FailureOr<llvm::SmallBitVector>
937 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
938  ArrayRef<OpFoldResult> sizes) {
939  llvm::SmallBitVector unusedDims(originalType.getRank());
940  if (originalType.getRank() == reducedType.getRank())
941  return unusedDims;
942 
943  for (const auto &dim : llvm::enumerate(sizes))
944  if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
945  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
946  unusedDims.set(dim.index());
947 
948  // Early exit for the case where the number of unused dims matches the number
949  // of ranks reduced.
950  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
951  originalType.getRank())
952  return unusedDims;
953 
954  SmallVector<int64_t> originalStrides, candidateStrides;
955  int64_t originalOffset, candidateOffset;
956  if (failed(
957  getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
958  failed(
959  getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
960  return failure();
961 
962  // For memrefs, a dimension is truly dropped if its corresponding stride is
963  // also dropped. This is particularly important when more than one of the dims
964  // is 1. Track the number of occurrences of the strides in the original type
965  // and the candidate type. For each unused dim that stride should not be
966  // present in the candidate type. Note that there could be multiple dimensions
967  // that have the same size. We don't need to exactly figure out which dim
968  // corresponds to which stride, we just need to verify that the number of
969  // repetitions of a stride in the original + number of unused dims with that
970  // stride == number of repetitions of a stride in the candidate.
971  std::map<int64_t, unsigned> currUnaccountedStrides =
972  getNumOccurences(originalStrides);
973  std::map<int64_t, unsigned> candidateStridesNumOccurences =
974  getNumOccurences(candidateStrides);
975  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
976  if (!unusedDims.test(dim))
977  continue;
978  int64_t originalStride = originalStrides[dim];
979  if (currUnaccountedStrides[originalStride] >
980  candidateStridesNumOccurences[originalStride]) {
981  // This dim can be treated as dropped.
982  currUnaccountedStrides[originalStride]--;
983  continue;
984  }
985  if (currUnaccountedStrides[originalStride] ==
986  candidateStridesNumOccurences[originalStride]) {
987  // The stride for this is not dropped. Keep as is.
988  unusedDims.reset(dim);
989  continue;
990  }
991  if (currUnaccountedStrides[originalStride] <
992  candidateStridesNumOccurences[originalStride]) {
993  // This should never happen. Can't have a stride in the reduced rank type
994  // that wasn't in the original one.
995  return failure();
996  }
997  }
998 
999  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1000  originalType.getRank())
1001  return failure();
1002  return unusedDims;
1003 }
1004 
1005 llvm::SmallBitVector SubViewOp::getDroppedDims() {
1006  MemRefType sourceType = getSourceType();
1007  MemRefType resultType = getType();
1008  FailureOr<llvm::SmallBitVector> unusedDims =
1009  computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1010  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1011  return *unusedDims;
1012 }
1013 
1014 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1015  // All forms of folding require a known index.
1016  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1017  if (!index)
1018  return {};
1019 
1020  // Folding for unranked types (UnrankedMemRefType) is not supported.
1021  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1022  if (!memrefType)
1023  return {};
1024 
1025  // Out of bound indices produce undefined behavior but are still valid IR.
1026  // Don't choke on them.
1027  int64_t indexVal = index.getInt();
1028  if (indexVal < 0 || indexVal >= memrefType.getRank())
1029  return {};
1030 
1031  // Fold if the shape extent along the given index is known.
1032  if (!memrefType.isDynamicDim(index.getInt())) {
1033  Builder builder(getContext());
1034  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1035  }
1036 
1037  // The size at the given index is now known to be a dynamic size.
1038  unsigned unsignedIndex = index.getValue().getZExtValue();
1039 
1040  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1041  Operation *definingOp = getSource().getDefiningOp();
1042 
1043  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1044  return *(alloc.getDynamicSizes().begin() +
1045  memrefType.getDynamicDimIndex(unsignedIndex));
1046 
1047  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1048  return *(alloca.getDynamicSizes().begin() +
1049  memrefType.getDynamicDimIndex(unsignedIndex));
1050 
1051  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1052  return *(view.getDynamicSizes().begin() +
1053  memrefType.getDynamicDimIndex(unsignedIndex));
1054 
1055  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1056  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1057  unsigned resultIndex = 0;
1058  unsigned sourceRank = subview.getSourceType().getRank();
1059  unsigned sourceIndex = 0;
1060  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1061  if (unusedDims.test(i))
1062  continue;
1063  if (resultIndex == unsignedIndex) {
1064  sourceIndex = i;
1065  break;
1066  }
1067  resultIndex++;
1068  }
1069  assert(subview.isDynamicSize(sourceIndex) &&
1070  "expected dynamic subview size");
1071  return subview.getDynamicSize(sourceIndex);
1072  }
1073 
1074  if (auto sizeInterface =
1075  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1076  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1077  "Expected dynamic subview size");
1078  return sizeInterface.getDynamicSize(unsignedIndex);
1079  }
1080 
1081  // dim(memrefcast) -> dim
1082  if (succeeded(foldMemRefCast(*this)))
1083  return getResult();
1084 
1085  return {};
1086 }
1087 
1088 namespace {
1089 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1090 /// operand.
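/// For example (an illustrative sketch; %m, %shape, and %c1 are assumed SSA
/// values):
/// ```
/// %r = memref.reshape %m(%shape)
///     : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
/// %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
/// is rewritten so that %d is `memref.load %shape[%c1] : memref<2xindex>`.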
1091 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1092  using OpRewritePattern<DimOp>::OpRewritePattern;
1093 
1094  LogicalResult matchAndRewrite(DimOp dim,
1095  PatternRewriter &rewriter) const override {
1096  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1097 
1098  if (!reshape)
1099  return rewriter.notifyMatchFailure(
1100  dim, "Dim op is not defined by a reshape op.");
1101 
1102  // dim of a memref reshape can be folded if dim.getIndex() dominates the
1103  // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1104  // cheaply check that either of the following conditions hold:
1105  // 1. dim.getIndex() is defined in the same block as reshape but before
1106  // reshape.
1107  // 2. dim.getIndex() is defined in a parent block of
1108  // reshape.
1109 
1110  // Check condition 1
1111  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1112  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1113  if (reshape->isBeforeInBlock(definingOp)) {
1114  return rewriter.notifyMatchFailure(
1115  dim,
1116  "dim.getIndex is not defined before reshape in the same block.");
1117  }
1118  } // else dim.getIndex is a block argument to reshape->getBlock and
1119  // dominates reshape
1120  } // Check condition 2
1121  else if (dim->getBlock() != reshape->getBlock() &&
1122  !dim.getIndex().getParentRegion()->isProperAncestor(
1123  reshape->getParentRegion())) {
1124  // If dim and reshape are in the same block but dim.getIndex() isn't, we
1125  // already know dim.getIndex() dominates reshape without calling
1126  // `isProperAncestor`
1127  return rewriter.notifyMatchFailure(
1128  dim, "dim.getIndex does not dominate reshape.");
1129  }
1130 
1131  // Place the load directly after the reshape to ensure that the shape memref
1132  // was not mutated.
1133  rewriter.setInsertionPointAfter(reshape);
1134  Location loc = dim.getLoc();
1135  Value load =
1136  rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1137  if (load.getType() != dim.getType())
1138  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1139  rewriter.replaceOp(dim, load);
1140  return success();
1141  }
1142 };
1143 
1144 } // namespace
1145 
1146 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1147  MLIRContext *context) {
1148  results.add<DimOfMemRefReshape>(context);
1149 }
1150 
1151 // ---------------------------------------------------------------------------
1152 // DmaStartOp
1153 // ---------------------------------------------------------------------------
1154 
1155 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1156  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1157  ValueRange destIndices, Value numElements,
1158  Value tagMemRef, ValueRange tagIndices, Value stride,
1159  Value elementsPerStride) {
1160  result.addOperands(srcMemRef);
1161  result.addOperands(srcIndices);
1162  result.addOperands(destMemRef);
1163  result.addOperands(destIndices);
1164  result.addOperands({numElements, tagMemRef});
1165  result.addOperands(tagIndices);
1166  if (stride)
1167  result.addOperands({stride, elementsPerStride});
1168 }
1169 
1170 void DmaStartOp::print(OpAsmPrinter &p) {
1171  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1172  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1173  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1174  if (isStrided())
1175  p << ", " << getStride() << ", " << getNumElementsPerStride();
1176 
1177  p.printOptionalAttrDict((*this)->getAttrs());
1178  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1179  << ", " << getTagMemRef().getType();
1180 }
1181 
1182 // Parse DmaStartOp.
1183 // Ex:
1184 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1185 // %tag[%index], %stride, %num_elt_per_stride :
1186 // : memref<3076 x f32, 0>,
1187 // memref<1024 x f32, 2>,
1188 // memref<1 x i32>
1189 //
1190 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1191  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1192  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1193  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1194  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1195  OpAsmParser::UnresolvedOperand numElementsInfo;
1196  OpAsmParser::UnresolvedOperand tagMemrefInfo;
1197  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1198  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1199 
1200  SmallVector<Type, 3> types;
1201  auto indexType = parser.getBuilder().getIndexType();
1202 
1203  // Parse and resolve the following list of operands:
1204  // *) source memref followed by its indices (in square brackets).
1205  // *) destination memref followed by its indices (in square brackets).
1206  // *) dma size in KiB.
1207  if (parser.parseOperand(srcMemRefInfo) ||
1208  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1209  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1210  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1211  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1212  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1213  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1214  return failure();
1215 
1216  // Parse optional stride and elements per stride.
1217  if (parser.parseTrailingOperandList(strideInfo))
1218  return failure();
1219 
1220  bool isStrided = strideInfo.size() == 2;
1221  if (!strideInfo.empty() && !isStrided) {
1222  return parser.emitError(parser.getNameLoc(),
1223  "expected two stride related operands");
1224  }
1225 
1226  if (parser.parseColonTypeList(types))
1227  return failure();
1228  if (types.size() != 3)
1229  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1230 
1231  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1232  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1233  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1234  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1235  // size should be an index.
1236  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1237  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1238  // tag indices should be index.
1239  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1240  return failure();
1241 
1242  if (isStrided) {
1243  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1244  return failure();
1245  }
1246 
1247  return success();
1248 }
1249 
1250 LogicalResult DmaStartOp::verify() {
1251  unsigned numOperands = getNumOperands();
1252 
1253  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1254  // the number of elements.
1255  if (numOperands < 4)
1256  return emitOpError("expected at least 4 operands");
1257 
1258  // Check types of operands. The order of these calls is important: the later
1259  // calls rely on some type properties to compute the operand position.
1260  // 1. Source memref.
1261  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1262  return emitOpError("expected source to be of memref type");
1263  if (numOperands < getSrcMemRefRank() + 4)
1264  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1265  << " operands";
1266  if (!getSrcIndices().empty() &&
1267  !llvm::all_of(getSrcIndices().getTypes(),
1268  [](Type t) { return t.isIndex(); }))
1269  return emitOpError("expected source indices to be of index type");
1270 
1271  // 2. Destination memref.
1272  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1273  return emitOpError("expected destination to be of memref type");
1274  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1275  if (numOperands < numExpectedOperands)
1276  return emitOpError() << "expected at least " << numExpectedOperands
1277  << " operands";
1278  if (!getDstIndices().empty() &&
1279  !llvm::all_of(getDstIndices().getTypes(),
1280  [](Type t) { return t.isIndex(); }))
1281  return emitOpError("expected destination indices to be of index type");
1282 
1283  // 3. Number of elements.
1284  if (!getNumElements().getType().isIndex())
1285  return emitOpError("expected num elements to be of index type");
1286 
1287  // 4. Tag memref.
1288  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1289  return emitOpError("expected tag to be of memref type");
1290  numExpectedOperands += getTagMemRefRank();
1291  if (numOperands < numExpectedOperands)
1292  return emitOpError() << "expected at least " << numExpectedOperands
1293  << " operands";
1294  if (!getTagIndices().empty() &&
1295  !llvm::all_of(getTagIndices().getTypes(),
1296  [](Type t) { return t.isIndex(); }))
1297  return emitOpError("expected tag indices to be of index type");
1298 
1299  // Optional stride-related operands must be either both present or both
1300  // absent.
1301  if (numOperands != numExpectedOperands &&
1302  numOperands != numExpectedOperands + 2)
1303  return emitOpError("incorrect number of operands");
1304 
1305  // 5. Strides.
1306  if (isStrided()) {
1307  if (!getStride().getType().isIndex() ||
1308  !getNumElementsPerStride().getType().isIndex())
1309  return emitOpError(
1310  "expected stride and num elements per stride to be of type index");
1311  }
1312 
1313  return success();
1314 }
1315 
1316 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1317  SmallVectorImpl<OpFoldResult> &results) {
1318  /// dma_start(memrefcast) -> dma_start
1319  return foldMemRefCast(*this);
1320 }
1321 
1322 // ---------------------------------------------------------------------------
1323 // DmaWaitOp
1324 // ---------------------------------------------------------------------------
1325 
1326 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1327  SmallVectorImpl<OpFoldResult> &results) {
1328  /// dma_wait(memrefcast) -> dma_wait
1329  return foldMemRefCast(*this);
1330 }
1331 
1332 LogicalResult DmaWaitOp::verify() {
1333  // Check that the number of tag indices matches the tagMemRef rank.
1334  unsigned numTagIndices = getTagIndices().size();
1335  unsigned tagMemRefRank = getTagMemRefRank();
1336  if (numTagIndices != tagMemRefRank)
1337  return emitOpError() << "expected tagIndices to have the same number of "
1338  "elements as the tagMemRef rank, expected "
1339  << tagMemRefRank << ", but got " << numTagIndices;
1340  return success();
1341 }
1342 
1343 //===----------------------------------------------------------------------===//
1344 // ExtractAlignedPointerAsIndexOp
1345 //===----------------------------------------------------------------------===//
1346 
1347 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1348  function_ref<void(Value, StringRef)> setNameFn) {
1349  setNameFn(getResult(), "intptr");
1350 }
1351 
1352 //===----------------------------------------------------------------------===//
1353 // ExtractStridedMetadataOp
1354 //===----------------------------------------------------------------------===//
1355 
1356 /// The number and type of the results are inferred from the
1357 /// shape of the source.
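/// For example (illustrative): for a source of type `memref<?x42xf32>` the
/// inferred results are a rank-0 `memref<f32>` base, one `index` offset, two
/// `index` sizes, and two `index` strides.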
1358 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1359  MLIRContext *context, std::optional<Location> location,
1360  ExtractStridedMetadataOp::Adaptor adaptor,
1361  SmallVectorImpl<Type> &inferredReturnTypes) {
1362  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1363  if (!sourceType)
1364  return failure();
1365 
1366  unsigned sourceRank = sourceType.getRank();
1367  IndexType indexType = IndexType::get(context);
1368  auto memrefType =
1369  MemRefType::get({}, sourceType.getElementType(),
1370  MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1371  // Base.
1372  inferredReturnTypes.push_back(memrefType);
1373  // Offset.
1374  inferredReturnTypes.push_back(indexType);
1375  // Sizes and strides.
1376  for (unsigned i = 0; i < sourceRank * 2; ++i)
1377  inferredReturnTypes.push_back(indexType);
1378  return success();
1379 }
1380 
1381 void ExtractStridedMetadataOp::getAsmResultNames(
1382  function_ref<void(Value, StringRef)> setNameFn) {
1383  setNameFn(getBaseBuffer(), "base_buffer");
1384  setNameFn(getOffset(), "offset");
1385  // For multi-result to work properly with pretty names and packed syntax `x:3`
1386  // we can only give a pretty name to the first value in the pack.
1387  if (!getSizes().empty()) {
1388  setNameFn(getSizes().front(), "sizes");
1389  setNameFn(getStrides().front(), "strides");
1390  }
1391 }
1392 
1393 /// Helper function to perform the replacement of all constant uses of `values`
1394 /// by a materialized constant extracted from `maybeConstants`.
1395 /// `values` and `maybeConstants` are expected to have the same size.
1396 template <typename Container>
1397 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1398  Container values,
1399  ArrayRef<OpFoldResult> maybeConstants) {
1400  assert(values.size() == maybeConstants.size() &&
1401  " expected values and maybeConstants of the same size");
1402  bool atLeastOneReplacement = false;
1403  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1404  // Don't materialize a constant if there are no uses: this would induce
1405  // infinite loops in the driver.
1406  if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1407  continue;
1408  assert(isa<Attribute>(maybeConstant) &&
1409  "The constified value should be either unchanged (i.e., == result) "
1410  "or a constant");
1411  Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1412  loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1413  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1414  // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1415  // yet.
1416  op->replaceUsesOfWith(result, constantVal);
1417  atLeastOneReplacement = true;
1418  }
1419  }
1420  return atLeastOneReplacement;
1421 }
1422 
1423 LogicalResult
1424 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1425  SmallVectorImpl<OpFoldResult> &results) {
1426  OpBuilder builder(*this);
1427 
1428  bool atLeastOneReplacement = replaceConstantUsesOf(
1429  builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1430  getConstifiedMixedOffset());
1431  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1432  getConstifiedMixedSizes());
1433  atLeastOneReplacement |= replaceConstantUsesOf(
1434  builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1435 
1436  return success(atLeastOneReplacement);
1437 }
1438 
1439 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1440  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1441  constifyIndexValues(values, getSource().getType(), getContext(),
1442  getConstantSizes, ShapedType::isDynamic);
1443  return values;
1444 }
1445 
1446 SmallVector<OpFoldResult>
1447 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1448  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1449  constifyIndexValues(values, getSource().getType(), getContext(),
1450  getConstantStrides, ShapedType::isDynamic);
1451  return values;
1452 }
1453 
1454 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1455  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1456  SmallVector<OpFoldResult> values(1, offsetOfr);
1457  constifyIndexValues(values, getSource().getType(), getContext(),
1458  getConstantOffset, ShapedType::isDynamic);
1459  return values[0];
1460 }
1461 
1462 //===----------------------------------------------------------------------===//
1463 // GenericAtomicRMWOp
1464 //===----------------------------------------------------------------------===//
1465 
1466 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1467  Value memref, ValueRange ivs) {
1468  OpBuilder::InsertionGuard g(builder);
1469  result.addOperands(memref);
1470  result.addOperands(ivs);
1471 
1472  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1473  Type elementType = memrefType.getElementType();
1474  result.addTypes(elementType);
1475 
1476  Region *bodyRegion = result.addRegion();
1477  builder.createBlock(bodyRegion);
1478  bodyRegion->addArgument(elementType, memref.getLoc());
1479  }
1480 }
1481 
1482 LogicalResult GenericAtomicRMWOp::verify() {
1483  auto &body = getRegion();
1484  if (body.getNumArguments() != 1)
1485  return emitOpError("expected single number of entry block arguments");
1486 
1487  if (getResult().getType() != body.getArgument(0).getType())
1488  return emitOpError("expected block argument of the same type as result type");
1489 
1490  bool hasSideEffects =
1491  body.walk([&](Operation *nestedOp) {
1492  if (isMemoryEffectFree(nestedOp))
1493  return WalkResult::advance();
1494  nestedOp->emitError(
1495  "body of 'memref.generic_atomic_rmw' should contain "
1496  "only operations with no side effects");
1497  return WalkResult::interrupt();
1498  })
1499  .wasInterrupted();
1500  return hasSideEffects ? failure() : success();
1501 }
1502 
1503 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1504  OperationState &result) {
1505  OpAsmParser::UnresolvedOperand memref;
1506  Type memrefType;
1507  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1508 
1509  Type indexType = parser.getBuilder().getIndexType();
1510  if (parser.parseOperand(memref) ||
1511  parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1512  parser.parseColonType(memrefType) ||
1513  parser.resolveOperand(memref, memrefType, result.operands) ||
1514  parser.resolveOperands(ivs, indexType, result.operands))
1515  return failure();
1516 
1517  Region *body = result.addRegion();
1518  if (parser.parseRegion(*body, {}) ||
1519  parser.parseOptionalAttrDict(result.attributes))
1520  return failure();
1521  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1522  return success();
1523 }
1524 
1525 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1526  p << ' ' << getMemref() << "[" << getIndices()
1527  << "] : " << getMemref().getType() << ' ';
1528  p.printRegion(getRegion());
1529  p.printOptionalAttrDict((*this)->getAttrs());
1530 }
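// As a rough, hand-written illustration (not copied from the op documentation),
// the form emitted by the printer above and accepted by the parser is:
//
//   %res = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
//   ^bb0(%current_value : f32):
//     %c1 = arith.constant 1.0 : f32
//     %inc = arith.addf %current_value, %c1 : f32
//     memref.atomic_yield %inc : f32
//   }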
1531 
1532 //===----------------------------------------------------------------------===//
1533 // AtomicYieldOp
1534 //===----------------------------------------------------------------------===//
1535 
1536 LogicalResult AtomicYieldOp::verify() {
1537  Type parentType = (*this)->getParentOp()->getResultTypes().front();
1538  Type resultType = getResult().getType();
1539  if (parentType != resultType)
1540  return emitOpError() << "types mismatch between yield op: " << resultType
1541  << " and its parent: " << parentType;
1542  return success();
1543 }
1544 
1545 //===----------------------------------------------------------------------===//
1546 // GlobalOp
1547 //===----------------------------------------------------------------------===//
1548 
1549 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1550  TypeAttr type,
1551  Attribute initialValue) {
1552  p << type;
1553  if (!op.isExternal()) {
1554  p << " = ";
1555  if (op.isUninitialized())
1556  p << "uninitialized";
1557  else
1558  p.printAttributeWithoutType(initialValue);
1559  }
1560 }
1561 
1562 static ParseResult
1563 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1564  Attribute &initialValue) {
1565  Type type;
1566  if (parser.parseType(type))
1567  return failure();
1568 
1569  auto memrefType = llvm::dyn_cast<MemRefType>(type);
1570  if (!memrefType || !memrefType.hasStaticShape())
1571  return parser.emitError(parser.getNameLoc())
1572  << "type should be static shaped memref, but got " << type;
1573  typeAttr = TypeAttr::get(type);
1574 
1575  if (parser.parseOptionalEqual())
1576  return success();
1577 
1578  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1579  initialValue = UnitAttr::get(parser.getContext());
1580  return success();
1581  }
1582 
1583  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1584  if (parser.parseAttribute(initialValue, tensorType))
1585  return failure();
1586  if (!llvm::isa<ElementsAttr>(initialValue))
1587  return parser.emitError(parser.getNameLoc())
1588  << "initial value should be a unit or elements attribute";
1589  return success();
1590 }
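// For illustration (a hand-written sketch): these two helpers print and parse
// the type-and-initializer tail of a memref.global, i.e. the part that reads,
// e.g., `memref<2xi32> = dense<[1, 4]>` in
//
//   memref.global "private" constant @c : memref<2xi32> = dense<[1, 4]>
//
// or `memref<3xf16> = uninitialized` for an uninitialized global, while an
// external global prints the type alone.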
1591 
1592 LogicalResult GlobalOp::verify() {
1593  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1594  if (!memrefType || !memrefType.hasStaticShape())
1595  return emitOpError("type should be static shaped memref, but got ")
1596  << getType();
1597 
1598  // Verify that the initial value, if present, is either a unit attribute or
1599  // an elements attribute.
1600  if (getInitialValue().has_value()) {
1601  Attribute initValue = getInitialValue().value();
1602  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1603  return emitOpError("initial value should be a unit or elements "
1604  "attribute, but got ")
1605  << initValue;
1606 
1607  // Check that the type of the initial value is compatible with the type of
1608  // the global variable.
1609  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1610  Type initType = elementsAttr.getType();
1611  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1612  if (initType != tensorType)
1613  return emitOpError("initial value expected to be of type ")
1614  << tensorType << ", but was of type " << initType;
1615  }
1616  }
1617 
1618  if (std::optional<uint64_t> alignAttr = getAlignment()) {
1619  uint64_t alignment = *alignAttr;
1620 
1621  if (!llvm::isPowerOf2_64(alignment))
1622  return emitError() << "alignment attribute value " << alignment
1623  << " is not a power of 2";
1624  }
1625 
1626  // TODO: verify visibility for declarations.
1627  return success();
1628 }
1629 
1630 ElementsAttr GlobalOp::getConstantInitValue() {
1631  auto initVal = getInitialValue();
1632  if (getConstant() && initVal.has_value())
1633  return llvm::cast<ElementsAttr>(initVal.value());
1634  return {};
1635 }
1636 
1637 //===----------------------------------------------------------------------===//
1638 // GetGlobalOp
1639 //===----------------------------------------------------------------------===//
1640 
1641 LogicalResult
1642 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1643  // Verify that the result type is same as the type of the referenced
1644  // memref.global op.
1645  auto global =
1646  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1647  if (!global)
1648  return emitOpError("'")
1649  << getName() << "' does not reference a valid global memref";
1650 
1651  Type resultType = getResult().getType();
1652  if (global.getType() != resultType)
1653  return emitOpError("result type ")
1654  << resultType << " does not match type " << global.getType()
1655  << " of the global memref @" << getName();
1656  return success();
1657 }
1658 
1659 //===----------------------------------------------------------------------===//
1660 // LoadOp
1661 //===----------------------------------------------------------------------===//
1662 
1663 LogicalResult LoadOp::verify() {
1664  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1665  return emitOpError("incorrect number of indices for load, expected ")
1666  << getMemRefType().getRank() << " but got " << getIndices().size();
1667  }
1668  return success();
1669 }
1670 
1671 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1672  /// load(memrefcast) -> load
1673  if (succeeded(foldMemRefCast(*this)))
1674  return getResult();
1675  return OpFoldResult();
1676 }
1677 
1678 //===----------------------------------------------------------------------===//
1679 // MemorySpaceCastOp
1680 //===----------------------------------------------------------------------===//
1681 
1682 void MemorySpaceCastOp::getAsmResultNames(
1683  function_ref<void(Value, StringRef)> setNameFn) {
1684  setNameFn(getResult(), "memspacecast");
1685 }
1686 
1687 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1688  if (inputs.size() != 1 || outputs.size() != 1)
1689  return false;
1690  Type a = inputs.front(), b = outputs.front();
1691  auto aT = llvm::dyn_cast<MemRefType>(a);
1692  auto bT = llvm::dyn_cast<MemRefType>(b);
1693 
1694  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1695  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1696 
1697  if (aT && bT) {
1698  if (aT.getElementType() != bT.getElementType())
1699  return false;
1700  if (aT.getLayout() != bT.getLayout())
1701  return false;
1702  if (aT.getShape() != bT.getShape())
1703  return false;
1704  return true;
1705  }
1706  if (uaT && ubT) {
1707  return uaT.getElementType() == ubT.getElementType();
1708  }
1709  return false;
1710 }
1711 
1712 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1713  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1714  // t2)
1715  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1716  getSourceMutable().assign(parentCast.getSource());
1717  return getResult();
1718  }
1719  return Value{};
1720 }
1721 
1722 //===----------------------------------------------------------------------===//
1723 // PrefetchOp
1724 //===----------------------------------------------------------------------===//
1725 
1726 void PrefetchOp::print(OpAsmPrinter &p) {
1727  p << " " << getMemref() << '[';
1728  p.printOperands(getIndices());
1729  p << ']' << ", " << (getIsWrite() ? "write" : "read");
1730  p << ", locality<" << getLocalityHint();
1731  p << ">, " << (getIsDataCache() ? "data" : "instr");
1732  p.printOptionalAttrDict(
1733  (*this)->getAttrs(),
1734  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1735  p << " : " << getMemRefType();
1736 }
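// As a hand-written illustration of the syntax handled by this printer and the
// parser below:
//
//   memref.prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xi32>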
1737 
1738 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1739  OpAsmParser::UnresolvedOperand memrefInfo;
1740  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1741  IntegerAttr localityHint;
1742  MemRefType type;
1743  StringRef readOrWrite, cacheType;
1744 
1745  auto indexTy = parser.getBuilder().getIndexType();
1746  auto i32Type = parser.getBuilder().getIntegerType(32);
1747  if (parser.parseOperand(memrefInfo) ||
1748  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1749  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1750  parser.parseComma() || parser.parseKeyword("locality") ||
1751  parser.parseLess() ||
1752  parser.parseAttribute(localityHint, i32Type, "localityHint",
1753  result.attributes) ||
1754  parser.parseGreater() || parser.parseComma() ||
1755  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1756  parser.resolveOperand(memrefInfo, type, result.operands) ||
1757  parser.resolveOperands(indexInfo, indexTy, result.operands))
1758  return failure();
1759 
1760  if (readOrWrite != "read" && readOrWrite != "write")
1761  return parser.emitError(parser.getNameLoc(),
1762  "rw specifier has to be 'read' or 'write'");
1763  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1764  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1765 
1766  if (cacheType != "data" && cacheType != "instr")
1767  return parser.emitError(parser.getNameLoc(),
1768  "cache type has to be 'data' or 'instr'");
1769 
1770  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1771  parser.getBuilder().getBoolAttr(cacheType == "data"));
1772 
1773  return success();
1774 }
1775 
1776 LogicalResult PrefetchOp::verify() {
1777  if (getNumOperands() != 1 + getMemRefType().getRank())
1778  return emitOpError("too few indices");
1779 
1780  return success();
1781 }
1782 
1783 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1784  SmallVectorImpl<OpFoldResult> &results) {
1785  // prefetch(memrefcast) -> prefetch
1786  return foldMemRefCast(*this);
1787 }
1788 
1789 //===----------------------------------------------------------------------===//
1790 // RankOp
1791 //===----------------------------------------------------------------------===//
1792 
1793 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1794  // Constant fold rank when the rank of the operand is known.
1795  auto type = getOperand().getType();
1796  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1797  if (shapedType && shapedType.hasRank())
1798  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1799  return IntegerAttr();
1800 }
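// For example (illustrative): `memref.rank %m : memref<4x?xf32>` folds to the
// index constant 2, whereas the rank of an unranked memref<*xf32> cannot be
// folded and is left to be computed at runtime.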
1801 
1802 //===----------------------------------------------------------------------===//
1803 // ReinterpretCastOp
1804 //===----------------------------------------------------------------------===//
1805 
1806 void ReinterpretCastOp::getAsmResultNames(
1807  function_ref<void(Value, StringRef)> setNameFn) {
1808  setNameFn(getResult(), "reinterpret_cast");
1809 }
1810 
1811 /// Build a ReinterpretCastOp with mixed static and dynamic entries:
1812 /// `staticOffsets`, `staticSizes` and `staticStrides` are automatically
1813 /// filled with sentinel values for every entry that is dynamic.
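/// For instance (an illustrative sketch of the dispatch below): calling this
/// builder with
///   offset  = someOffsetValue                      (a Value)
///   sizes   = {someSizeValue, b.getIndexAttr(4)}
///   strides = {b.getIndexAttr(4), b.getIndexAttr(1)}
/// records static_offsets = [dynamic], static_sizes = [dynamic, 4] and
/// static_strides = [4, 1], and passes the two Values as dynamic operands.
/// `someOffsetValue` and `someSizeValue` are placeholder names.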
1814 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1815  MemRefType resultType, Value source,
1816  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1817  ArrayRef<OpFoldResult> strides,
1818  ArrayRef<NamedAttribute> attrs) {
1819  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1820  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1821  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1822  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1823  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1824  result.addAttributes(attrs);
1825  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1826  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1827  b.getDenseI64ArrayAttr(staticSizes),
1828  b.getDenseI64ArrayAttr(staticStrides));
1829 }
1830 
1831 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1832  Value source, OpFoldResult offset,
1833  ArrayRef<OpFoldResult> sizes,
1834  ArrayRef<OpFoldResult> strides,
1835  ArrayRef<NamedAttribute> attrs) {
1836  auto sourceType = cast<BaseMemRefType>(source.getType());
1837  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1838  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1839  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1840  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1841  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1842  auto stridedLayout = StridedLayoutAttr::get(
1843  b.getContext(), staticOffsets.front(), staticStrides);
1844  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1845  stridedLayout, sourceType.getMemorySpace());
1846  build(b, result, resultType, source, offset, sizes, strides, attrs);
1847 }
1848 
1849 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1850  MemRefType resultType, Value source,
1851  int64_t offset, ArrayRef<int64_t> sizes,
1852  ArrayRef<int64_t> strides,
1853  ArrayRef<NamedAttribute> attrs) {
1854  SmallVector<OpFoldResult> sizeValues =
1855  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1856  return b.getI64IntegerAttr(v);
1857  }));
1858  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1859  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1860  return b.getI64IntegerAttr(v);
1861  }));
1862  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1863  strideValues, attrs);
1864 }
1865 
1866 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1867  MemRefType resultType, Value source, Value offset,
1868  ValueRange sizes, ValueRange strides,
1869  ArrayRef<NamedAttribute> attrs) {
1870  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1871  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1872  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1873  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1874  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1875 }
1876 
1877 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1878 // completed automatically, like we have for subview and extract_slice.
1879 LogicalResult ReinterpretCastOp::verify() {
1880  // The source and result memrefs should be in the same memory space.
1881  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1882  auto resultType = llvm::cast<MemRefType>(getType());
1883  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1884  return emitError("different memory spaces specified for source type ")
1885  << srcType << " and result memref type " << resultType;
1886  if (srcType.getElementType() != resultType.getElementType())
1887  return emitError("different element types specified for source type ")
1888  << srcType << " and result memref type " << resultType;
1889 
1890  // Match sizes in result memref type and in static_sizes attribute.
1891  for (auto [idx, resultSize, expectedSize] :
1892  llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1893  if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
1894  return emitError("expected result type with size = ")
1895  << (ShapedType::isDynamic(expectedSize)
1896  ? std::string("dynamic")
1897  : std::to_string(expectedSize))
1898  << " instead of " << resultSize << " in dim = " << idx;
1899  }
1900 
1901  // Match offset and strides in static_offsets and static_strides attributes. If
1902  // result memref type has no affine map specified, this will assume an
1903  // identity layout.
1904  int64_t resultOffset;
1905  SmallVector<int64_t, 4> resultStrides;
1906  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1907  return emitError("expected result type to have strided layout but found ")
1908  << resultType;
1909 
1910  // Match offset in result memref type and in static_offsets attribute.
1911  int64_t expectedOffset = getStaticOffsets().front();
1912  if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
1913  return emitError("expected result type with offset = ")
1914  << (ShapedType::isDynamic(expectedOffset)
1915  ? std::string("dynamic")
1916  : std::to_string(expectedOffset))
1917  << " instead of " << resultOffset;
1918 
1919  // Match strides in result memref type and in static_strides attribute.
1920  for (auto [idx, resultStride, expectedStride] :
1921  llvm::enumerate(resultStrides, getStaticStrides())) {
1922  if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
1923  return emitError("expected result type with stride = ")
1924  << (ShapedType::isDynamic(expectedStride)
1925  ? std::string("dynamic")
1926  : std::to_string(expectedStride))
1927  << " instead of " << resultStride << " in dim = " << idx;
1928  }
1929 
1930  return success();
1931 }
1932 
1933 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1934  Value src = getSource();
1935  auto getPrevSrc = [&]() -> Value {
1936  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1937  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1938  return prev.getSource();
1939 
1940  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1941  if (auto prev = src.getDefiningOp<CastOp>())
1942  return prev.getSource();
1943 
1944  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1945  // are 0.
1946  if (auto prev = src.getDefiningOp<SubViewOp>())
1947  if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1948  return isConstantIntValue(val, 0);
1949  }))
1950  return prev.getSource();
1951 
1952  return nullptr;
1953  };
1954 
1955  if (auto prevSrc = getPrevSrc()) {
1956  getSourceMutable().assign(prevSrc);
1957  return getResult();
1958  }
1959 
1960  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1961  if (!ShapedType::isDynamicShape(getType().getShape()) &&
1962  src.getType() == getType() && getStaticOffsets().front() == 0) {
1963  return src;
1964  }
1965 
1966  return nullptr;
1967 }
1968 
1969 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1970  SmallVector<OpFoldResult> values = getMixedSizes();
1971  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
1972  ShapedType::isDynamic);
1973  return values;
1974 }
1975 
1976 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1977  SmallVector<OpFoldResult> values = getMixedStrides();
1978  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
1979  ShapedType::isDynamic);
1980  return values;
1981 }
1982 
1983 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1984  SmallVector<OpFoldResult> values = getMixedOffsets();
1985  assert(values.size() == 1 &&
1986  "reinterpret_cast must have one and only one offset");
1987  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
1988  ShapedType::isDynamic);
1989  return values[0];
1990 }
1991 
1992 namespace {
1993 /// Replace the sequence:
1994 /// ```
1995 /// base, offset, sizes, strides = extract_strided_metadata src
1996 /// dst = reinterpret_cast base to offset, sizes, strides
1997 /// ```
1998 /// With
1999 ///
2000 /// ```
2001 /// dst = memref.cast src
2002 /// ```
2003 ///
2004 /// Note: The cast operation is only inserted when the type of dst and src
2005 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
2006 ///
2007 /// This pattern also matches when the offset, sizes, and strides don't come
2008 /// directly from the `extract_strided_metadata`'s results but it can be
2009 /// statically proven that they would hold the same values.
2010 ///
2011 /// For instance, the following sequence would be replaced:
2012 /// ```
2013 /// base, offset, sizes, strides =
2014 /// extract_strided_metadata memref : memref<3x4xty>
2015 /// dst = reinterpret_cast base to 0, [3, 4], strides
2016 /// ```
2017 /// Because we know (thanks to the type of the input memref) that the
2018 /// variables `offset` and `sizes` will respectively hold 0 and [3, 4].
2019 ///
2020 /// Similarly, the following sequence would be replaced:
2021 /// ```
2022 /// c0 = arith.constant 0
2023 /// c4 = arith.constant 4
2024 /// base, offset, sizes, strides =
2025 /// extract_strided_metadata memref : memref<3x4xty>
2026 /// dst = reinterpret_cast base to c0, [3, c4], strides
2027 /// ```
2028 /// Because we know that `offset` and `c0` will hold 0
2029 /// and `c4` will hold 4.
2030 struct ReinterpretCastOpExtractStridedMetadataFolder
2031  : public OpRewritePattern<ReinterpretCastOp> {
2032 public:
2033  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2034 
2035  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2036  PatternRewriter &rewriter) const override {
2037  auto extractStridedMetadata =
2038  op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2039  if (!extractStridedMetadata)
2040  return failure();
2041  // Check if the reinterpret cast reconstructs a memref with the exact same
2042  // properties as the extract strided metadata.
2043 
2044  // First, check that the strides are the same.
2045  SmallVector<OpFoldResult> extractStridesOfr =
2046  extractStridedMetadata.getConstifiedMixedStrides();
2047  SmallVector<OpFoldResult> reinterpretStridesOfr =
2048  op.getConstifiedMixedStrides();
2049  if (extractStridesOfr.size() != reinterpretStridesOfr.size())
2050  return failure();
2051 
2052  unsigned rank = op.getType().getRank();
2053  for (unsigned i = 0; i < rank; ++i) {
2054  if (extractStridesOfr[i] != reinterpretStridesOfr[i])
2055  return failure();
2056  }
2057 
2058  // Second, check the sizes.
2059  assert(extractStridedMetadata.getSizes().size() ==
2060  op.getMixedSizes().size() &&
2061  "Strides and sizes rank must match");
2062  SmallVector<OpFoldResult> extractSizesOfr =
2063  extractStridedMetadata.getConstifiedMixedSizes();
2064  SmallVector<OpFoldResult> reinterpretSizesOfr =
2065  op.getConstifiedMixedSizes();
2066  for (unsigned i = 0; i < rank; ++i) {
2067  if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2068  return failure();
2069  }
2070  // Finally, check the offset.
2071  assert(op.getMixedOffsets().size() == 1 &&
2072  "reinterpret_cast with more than one offset should have been "
2073  "rejected by the verifier");
2074  OpFoldResult extractOffsetOfr =
2075  extractStridedMetadata.getConstifiedMixedOffset();
2076  OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2077  if (extractOffsetOfr != reinterpretOffsetOfr)
2078  return failure();
2079 
2080  // At this point, we know that the back and forth between extract strided
2081  // metadata and reinterpret cast is a noop. However, the final type of the
2082  // reinterpret cast may not be exactly the same as the original memref.
2083  // E.g., it could be changing a dimension from static to dynamic. Check that
2084  // here and add a cast if necessary.
2085  Type srcTy = extractStridedMetadata.getSource().getType();
2086  if (srcTy == op.getResult().getType())
2087  rewriter.replaceOp(op, extractStridedMetadata.getSource());
2088  else
2089  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2090  extractStridedMetadata.getSource());
2091 
2092  return success();
2093  }
2094 };
2095 } // namespace
2096 
2097 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2098  MLIRContext *context) {
2099  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2100 }
2101 
2102 //===----------------------------------------------------------------------===//
2103 // Reassociative reshape ops
2104 //===----------------------------------------------------------------------===//
2105 
2106 void CollapseShapeOp::getAsmResultNames(
2107  function_ref<void(Value, StringRef)> setNameFn) {
2108  setNameFn(getResult(), "collapse_shape");
2109 }
2110 
2111 void ExpandShapeOp::getAsmResultNames(
2112  function_ref<void(Value, StringRef)> setNameFn) {
2113  setNameFn(getResult(), "expand_shape");
2114 }
2115 
2116 LogicalResult ExpandShapeOp::reifyResultShapes(
2117  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2118  reifiedResultShapes = {
2119  getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2120  return success();
2121 }
2122 
2123 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2124 /// result and operand. Layout maps are verified separately.
2125 ///
2126 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2127 /// allowed in a reassociation group.
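///
/// For instance (an illustrative example): with collapsedShape = [?, 12],
/// expandedShape = [?, 4, 3, 4] and reassociation = [[0, 1], [2, 3]], the
/// first group carries the dynamic dimension and the second group multiplies
/// to 3 * 4 = 12, so verification succeeds; expandedShape = [?, 4, 3, 5]
/// would be rejected because 3 * 5 != 12.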
2128 static LogicalResult
2130  ArrayRef<int64_t> expandedShape,
2131  ArrayRef<ReassociationIndices> reassociation,
2132  bool allowMultipleDynamicDimsPerGroup) {
2133  // There must be one reassociation group per collapsed dimension.
2134  if (collapsedShape.size() != reassociation.size())
2135  return op->emitOpError("invalid number of reassociation groups: found ")
2136  << reassociation.size() << ", expected " << collapsedShape.size();
2137 
2138  // The next expected expanded dimension index (while iterating over
2139  // reassociation indices).
2140  int64_t nextDim = 0;
2141  for (const auto &it : llvm::enumerate(reassociation)) {
2142  ReassociationIndices group = it.value();
2143  int64_t collapsedDim = it.index();
2144 
2145  bool foundDynamic = false;
2146  for (int64_t expandedDim : group) {
2147  if (expandedDim != nextDim++)
2148  return op->emitOpError("reassociation indices must be contiguous");
2149 
2150  if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2151  return op->emitOpError("reassociation index ")
2152  << expandedDim << " is out of bounds";
2153 
2154  // Check if there are multiple dynamic dims in a reassociation group.
2155  if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2156  if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2157  return op->emitOpError(
2158  "at most one dimension in a reassociation group may be dynamic");
2159  foundDynamic = true;
2160  }
2161  }
2162 
2163  // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2164  if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2165  return op->emitOpError("collapsed dim (")
2166  << collapsedDim
2167  << ") must be dynamic if and only if reassociation group is "
2168  "dynamic";
2169 
2170  // If all dims in the reassociation group are static, the size of the
2171  // collapsed dim can be verified.
2172  if (!foundDynamic) {
2173  int64_t groupSize = 1;
2174  for (int64_t expandedDim : group)
2175  groupSize *= expandedShape[expandedDim];
2176  if (groupSize != collapsedShape[collapsedDim])
2177  return op->emitOpError("collapsed dim size (")
2178  << collapsedShape[collapsedDim]
2179  << ") must equal reassociation group size (" << groupSize << ")";
2180  }
2181  }
2182 
2183  if (collapsedShape.empty()) {
2184  // Rank 0: All expanded dimensions must be 1.
2185  for (int64_t d : expandedShape)
2186  if (d != 1)
2187  return op->emitOpError(
2188  "rank 0 memrefs can only be extended/collapsed with/from ones");
2189  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2190  // Rank >= 1: Number of dimensions among all reassociation groups must match
2191  // the result memref rank.
2192  return op->emitOpError("expanded rank (")
2193  << expandedShape.size()
2194  << ") inconsistent with number of reassociation indices (" << nextDim
2195  << ")";
2196  }
2197 
2198  return success();
2199 }
2200 
2201 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2202  return getSymbolLessAffineMaps(getReassociationExprs());
2203 }
2204 
2205 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2206  return convertReassociationIndicesToExprs(getContext(),
2207  getReassociationIndices());
2208 }
2209 
2210 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2211  return getSymbolLessAffineMaps(getReassociationExprs());
2212 }
2213 
2214 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2215  return convertReassociationIndicesToExprs(getContext(),
2216  getReassociationIndices());
2217 }
2218 
2219 /// Compute the layout map after expanding a given source MemRef type with the
2220 /// specified reassociation indices.
2221 static FailureOr<StridedLayoutAttr>
2222 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2223  ArrayRef<ReassociationIndices> reassociation) {
2224  int64_t srcOffset;
2225  SmallVector<int64_t> srcStrides;
2226  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2227  return failure();
2228  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2229 
2230  // 1-1 mapping between srcStrides and reassociation packs.
2231  // Each srcStride starts with the given value and gets expanded according to
2232  // the proper entries in resultShape.
2233  // Example:
2234  // srcStrides = [10000, 1 , 100 ],
2235  // reassociations = [ [0], [1], [2, 3, 4]],
2236  // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2237  // -> For the purpose of stride calculation, the useful sizes are:
2238  // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2239  // resultStrides = [10000, 1, 600, 200, 100]
2240  // Note that a stride does not get expanded along the first entry of each
2241  // shape pack.
2242  SmallVector<int64_t> reverseResultStrides;
2243  reverseResultStrides.reserve(resultShape.size());
2244  unsigned shapeIndex = resultShape.size() - 1;
2245  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2246  ReassociationIndices reassoc = std::get<0>(it);
2247  int64_t currentStrideToExpand = std::get<1>(it);
2248  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2249  reverseResultStrides.push_back(currentStrideToExpand);
2250  currentStrideToExpand =
2251  (SaturatedInteger::wrap(currentStrideToExpand) *
2252  SaturatedInteger::wrap(resultShape[shapeIndex--]))
2253  .asInteger();
2254  }
2255  }
2256  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2257  resultStrides.resize(resultShape.size(), 1);
2258  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2259 }
2260 
2261 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2262  MemRefType srcType, ArrayRef<int64_t> resultShape,
2263  ArrayRef<ReassociationIndices> reassociation) {
2264  if (srcType.getLayout().isIdentity()) {
2265  // If the source is contiguous (i.e., no layout map specified), so is the
2266  // result.
2267  MemRefLayoutAttrInterface layout;
2268  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2269  srcType.getMemorySpace());
2270  }
2271 
2272  // Source may not be contiguous. Compute the layout map.
2273  FailureOr<StridedLayoutAttr> computedLayout =
2274  computeExpandedLayoutMap(srcType, resultShape, reassociation);
2275  if (failed(computedLayout))
2276  return failure();
2277  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2278  srcType.getMemorySpace());
2279 }
2280 
2281 FailureOr<SmallVector<OpFoldResult>>
2282 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2283  MemRefType expandedType,
2284  ArrayRef<ReassociationIndices> reassociation,
2285  ArrayRef<OpFoldResult> inputShape) {
2286  std::optional<SmallVector<OpFoldResult>> outputShape =
2287  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2288  inputShape);
2289  if (!outputShape)
2290  return failure();
2291  return *outputShape;
2292 }
2293 
2294 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2295  Type resultType, Value src,
2296  ArrayRef<ReassociationIndices> reassociation,
2297  ArrayRef<OpFoldResult> outputShape) {
2298  auto [staticOutputShape, dynamicOutputShape] =
2299  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2300  build(builder, result, llvm::cast<MemRefType>(resultType), src,
2301  getReassociationIndicesAttribute(builder, reassociation),
2302  dynamicOutputShape, staticOutputShape);
2303 }
2304 
2305 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2306  Type resultType, Value src,
2307  ArrayRef<ReassociationIndices> reassociation) {
2308  SmallVector<OpFoldResult> inputShape =
2309  getMixedSizes(builder, result.location, src);
2310  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2311  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2312  builder, result.location, memrefResultTy, reassociation, inputShape);
2313  // Failure of this assertion usually indicates presence of multiple
2314  // dynamic dimensions in the same reassociation group.
2315  assert(succeeded(outputShape) && "unable to infer output shape");
2316  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2317 }
2318 
2319 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2320  ArrayRef<int64_t> resultShape, Value src,
2321  ArrayRef<ReassociationIndices> reassociation) {
2322  // Only ranked memref source values are supported.
2323  auto srcType = llvm::cast<MemRefType>(src.getType());
2324  FailureOr<MemRefType> resultType =
2325  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2326  // Failure of this assertion usually indicates a problem with the source
2327  // type, e.g., could not get strides/offset.
2328  assert(succeeded(resultType) && "could not compute layout");
2329  build(builder, result, *resultType, src, reassociation);
2330 }
2331 
2332 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2333  ArrayRef<int64_t> resultShape, Value src,
2334  ArrayRef<ReassociationIndices> reassociation,
2335  ArrayRef<OpFoldResult> outputShape) {
2336  // Only ranked memref source values are supported.
2337  auto srcType = llvm::cast<MemRefType>(src.getType());
2338  FailureOr<MemRefType> resultType =
2339  ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2340  // Failure of this assertion usually indicates a problem with the source
2341  // type, e.g., could not get strides/offset.
2342  assert(succeeded(resultType) && "could not compute layout");
2343  build(builder, result, *resultType, src, reassociation, outputShape);
2344 }
2345 
2346 LogicalResult ExpandShapeOp::verify() {
2347  MemRefType srcType = getSrcType();
2348  MemRefType resultType = getResultType();
2349 
2350  if (srcType.getRank() > resultType.getRank()) {
2351  auto r0 = srcType.getRank();
2352  auto r1 = resultType.getRank();
2353  return emitOpError("has source rank ")
2354  << r0 << " and result rank " << r1 << ". This is not an expansion ("
2355  << r0 << " > " << r1 << ").";
2356  }
2357 
2358  // Verify result shape.
2359  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2360  resultType.getShape(),
2361  getReassociationIndices(),
2362  /*allowMultipleDynamicDimsPerGroup=*/true)))
2363  return failure();
2364 
2365  // Compute expected result type (including layout map).
2366  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2367  srcType, resultType.getShape(), getReassociationIndices());
2368  if (failed(expectedResultType))
2369  return emitOpError("invalid source layout map");
2370 
2371  // Check actual result type.
2372  if (*expectedResultType != resultType)
2373  return emitOpError("expected expanded type to be ")
2374  << *expectedResultType << " but found " << resultType;
2375 
2376  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2377  return emitOpError("expected number of static shape bounds to be equal to "
2378  "the output rank (")
2379  << resultType.getRank() << ") but found "
2380  << getStaticOutputShape().size() << " inputs instead";
2381 
2382  if ((int64_t)getOutputShape().size() !=
2383  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2384  return emitOpError("mismatch in dynamic dims in output_shape and "
2385  "static_output_shape: static_output_shape has ")
2386  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2387  << " dynamic dims while output_shape has " << getOutputShape().size()
2388  << " values";
2389 
2390  // Verify if provided output shapes are in agreement with output type.
2391  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2392  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2393  for (auto [pos, shape] : llvm::enumerate(resShape)) {
2394  if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
2395  return emitOpError("invalid output shape provided at pos ") << pos;
2396  }
2397  }
2398 
2399  return success();
2400 }
2401 
2402 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2403  MLIRContext *context) {
2404  results.add<
2405  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2406  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2407 }
2408 
2409 /// Compute the layout map after collapsing a given source MemRef type with the
2410 /// specified reassociation indices.
2411 ///
2412 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2413 /// not possible to check this by inspecting a MemRefType in the general case.
2414 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2415 /// be valid (and thus accepted by this function) unless `strict = true`.
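///
/// As an illustrative example: collapsing memref<4x8xf32, strided<[8, 1]>>
/// with reassociation [[0, 1]] yields the layout strided<[1]>, because the
/// inner stride 1 expanded over the trailing size 8 matches the outer stride 8
/// (contiguous); memref<4x8xf32, strided<[16, 1]>> is rejected since
/// 1 * 8 != 16.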
2416 static FailureOr<StridedLayoutAttr>
2417 computeCollapsedLayoutMap(MemRefType srcType,
2418  ArrayRef<ReassociationIndices> reassociation,
2419  bool strict = false) {
2420  int64_t srcOffset;
2421  SmallVector<int64_t> srcStrides;
2422  auto srcShape = srcType.getShape();
2423  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2424  return failure();
2425 
2426  // The result stride of a reassociation group is the stride of the last entry
2427  // of the reassociation. (TODO: Should be the minimum stride in the
2428  // reassociation because strides are not necessarily sorted. E.g., when using
2429  // memref.transpose.) Dimensions of size 1 should be skipped, because their
2430  // strides are meaningless and could have any arbitrary value.
2431  SmallVector<int64_t> resultStrides;
2432  resultStrides.reserve(reassociation.size());
2433  for (const ReassociationIndices &reassoc : reassociation) {
2434  ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2435  while (srcShape[ref.back()] == 1 && ref.size() > 1)
2436  ref = ref.drop_back();
2437  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2438  resultStrides.push_back(srcStrides[ref.back()]);
2439  } else {
2440  // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2441  // the corresponding stride may have to be skipped. (See above comment.)
2442  // Therefore, the result stride cannot be statically determined and must
2443  // be dynamic.
2444  resultStrides.push_back(ShapedType::kDynamic);
2445  }
2446  }
2447 
2448  // Validate that each reassociation group is contiguous.
2449  unsigned resultStrideIndex = resultStrides.size() - 1;
2450  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2451  auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2452  auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2453  for (int64_t idx : llvm::reverse(trailingReassocs)) {
2454  stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2455 
2456  // Both source and result stride must have the same static value. In that
2457  // case, we can be sure that the dimensions are collapsible (because they
2458  // are contiguous).
2459  // If `strict = false` (default during op verification), we accept cases
2460  // where one or both strides are dynamic. This is best effort: We reject
2461  // ops where obviously non-contiguous dims are collapsed, but accept ops
2462  // where we cannot be sure statically. Such ops may fail at runtime. See
2463  // the op documentation for details.
2464  auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2465  if (strict && (stride.saturated || srcStride.saturated))
2466  return failure();
2467 
2468  // Dimensions of size 1 should be skipped, because their strides are
2469  // meaningless and could have any arbitrary value.
2470  if (srcShape[idx - 1] == 1)
2471  continue;
2472 
2473  if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2474  return failure();
2475  }
2476  }
2477  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2478 }
2479 
2480 bool CollapseShapeOp::isGuaranteedCollapsible(
2481  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2482  // MemRefs with identity layout are always collapsible.
2483  if (srcType.getLayout().isIdentity())
2484  return true;
2485 
2486  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2487  /*strict=*/true));
2488 }
2489 
2490 MemRefType CollapseShapeOp::computeCollapsedType(
2491  MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2492  SmallVector<int64_t> resultShape;
2493  resultShape.reserve(reassociation.size());
2494  for (const ReassociationIndices &group : reassociation) {
2495  auto groupSize = SaturatedInteger::wrap(1);
2496  for (int64_t srcDim : group)
2497  groupSize =
2498  groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2499  resultShape.push_back(groupSize.asInteger());
2500  }
2501 
2502  if (srcType.getLayout().isIdentity()) {
2503  // If the source is contiguous (i.e., no layout map specified), so is the
2504  // result.
2505  MemRefLayoutAttrInterface layout;
2506  return MemRefType::get(resultShape, srcType.getElementType(), layout,
2507  srcType.getMemorySpace());
2508  }
2509 
2510  // Source may not be fully contiguous. Compute the layout map.
2511  // Note: Dimensions that are collapsed into a single dim are assumed to be
2512  // contiguous.
2513  FailureOr<StridedLayoutAttr> computedLayout =
2514  computeCollapsedLayoutMap(srcType, reassociation);
2515  assert(succeeded(computedLayout) &&
2516  "invalid source layout map or collapsing non-contiguous dims");
2517  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2518  srcType.getMemorySpace());
2519 }
2520 
2521 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2522  ArrayRef<ReassociationIndices> reassociation,
2523  ArrayRef<NamedAttribute> attrs) {
2524  auto srcType = llvm::cast<MemRefType>(src.getType());
2525  MemRefType resultType =
2526  CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2527  result.addAttribute(getReassociationAttrStrName(),
2528  getReassociationIndicesAttribute(b, reassociation));
2529  build(b, result, resultType, src, attrs);
2530 }
2531 
2532 LogicalResult CollapseShapeOp::verify() {
2533  MemRefType srcType = getSrcType();
2534  MemRefType resultType = getResultType();
2535 
2536  if (srcType.getRank() < resultType.getRank()) {
2537  auto r0 = srcType.getRank();
2538  auto r1 = resultType.getRank();
2539  return emitOpError("has source rank ")
2540  << r0 << " and result rank " << r1 << ". This is not a collapse ("
2541  << r0 << " < " << r1 << ").";
2542  }
2543 
2544  // Verify result shape.
2545  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2546  srcType.getShape(), getReassociationIndices(),
2547  /*allowMultipleDynamicDimsPerGroup=*/true)))
2548  return failure();
2549 
2550  // Compute expected result type (including layout map).
2551  MemRefType expectedResultType;
2552  if (srcType.getLayout().isIdentity()) {
2553  // If the source is contiguous (i.e., no layout map specified), so is the
2554  // result.
2555  MemRefLayoutAttrInterface layout;
2556  expectedResultType =
2557  MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2558  srcType.getMemorySpace());
2559  } else {
2560  // Source may not be fully contiguous. Compute the layout map.
2561  // Note: Dimensions that are collapsed into a single dim are assumed to be
2562  // contiguous.
2563  FailureOr<StridedLayoutAttr> computedLayout =
2564  computeCollapsedLayoutMap(srcType, getReassociationIndices());
2565  if (failed(computedLayout))
2566  return emitOpError(
2567  "invalid source layout map or collapsing non-contiguous dims");
2568  expectedResultType =
2569  MemRefType::get(resultType.getShape(), srcType.getElementType(),
2570  *computedLayout, srcType.getMemorySpace());
2571  }
2572 
2573  if (expectedResultType != resultType)
2574  return emitOpError("expected collapsed type to be ")
2575  << expectedResultType << " but found " << resultType;
2576 
2577  return success();
2578 }
2579 
2580 class CollapseShapeOpMemRefCastFolder
2581  : public OpRewritePattern<CollapseShapeOp> {
2582 public:
2583  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2584 
2585  LogicalResult matchAndRewrite(CollapseShapeOp op,
2586  PatternRewriter &rewriter) const override {
2587  auto cast = op.getOperand().getDefiningOp<CastOp>();
2588  if (!cast)
2589  return failure();
2590 
2591  if (!CastOp::canFoldIntoConsumerOp(cast))
2592  return failure();
2593 
2594  Type newResultType = CollapseShapeOp::computeCollapsedType(
2595  llvm::cast<MemRefType>(cast.getOperand().getType()),
2596  op.getReassociationIndices());
2597 
2598  if (newResultType == op.getResultType()) {
2599  rewriter.modifyOpInPlace(
2600  op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2601  } else {
2602  Value newOp = rewriter.create<CollapseShapeOp>(
2603  op->getLoc(), cast.getSource(), op.getReassociationIndices());
2604  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2605  }
2606  return success();
2607  }
2608 };
2609 
2610 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2611  MLIRContext *context) {
2612  results.add<
2613  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2614  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2615  memref::DimOp, MemRefType>,
2616  CollapseShapeOpMemRefCastFolder>(context);
2617 }
2618 
2619 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2620  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2621  adaptor.getOperands());
2622 }
2623 
2624 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2625  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2626  adaptor.getOperands());
2627 }
2628 
2629 //===----------------------------------------------------------------------===//
2630 // ReshapeOp
2631 //===----------------------------------------------------------------------===//
2632 
2633 void ReshapeOp::getAsmResultNames(
2634  function_ref<void(Value, StringRef)> setNameFn) {
2635  setNameFn(getResult(), "reshape");
2636 }
2637 
2638 LogicalResult ReshapeOp::verify() {
2639  Type operandType = getSource().getType();
2640  Type resultType = getResult().getType();
2641 
2642  Type operandElementType =
2643  llvm::cast<ShapedType>(operandType).getElementType();
2644  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2645  if (operandElementType != resultElementType)
2646  return emitOpError("element types of source and destination memref "
2647  "types should be the same");
2648 
2649  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2650  if (!operandMemRefType.getLayout().isIdentity())
2651  return emitOpError("source memref type should have identity affine map");
2652 
2653  int64_t shapeSize =
2654  llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2655  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2656  if (resultMemRefType) {
2657  if (!resultMemRefType.getLayout().isIdentity())
2658  return emitOpError("result memref type should have identity affine map");
2659  if (shapeSize == ShapedType::kDynamic)
2660  return emitOpError("cannot use shape operand with dynamic length to "
2661  "reshape to statically-ranked memref type");
2662  if (shapeSize != resultMemRefType.getRank())
2663  return emitOpError(
2664  "length of shape operand differs from the result's memref rank");
2665  }
2666  return success();
2667 }
2668 
2669 //===----------------------------------------------------------------------===//
2670 // StoreOp
2671 //===----------------------------------------------------------------------===//
2672 
2673 LogicalResult StoreOp::verify() {
2674  if (getNumOperands() != 2 + getMemRefType().getRank())
2675  return emitOpError("store index operand count not equal to memref rank");
2676 
2677  return success();
2678 }
2679 
2680 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2681  SmallVectorImpl<OpFoldResult> &results) {
2682  /// store(memrefcast) -> store
2683  return foldMemRefCast(*this, getValueToStore());
2684 }
2685 
2686 //===----------------------------------------------------------------------===//
2687 // SubViewOp
2688 //===----------------------------------------------------------------------===//
2689 
2690 void SubViewOp::getAsmResultNames(
2691  function_ref<void(Value, StringRef)> setNameFn) {
2692  setNameFn(getResult(), "subview");
2693 }
2694 
2695 /// A subview result type can be fully inferred from the source type and the
2696 /// static representation of offsets, sizes and strides. Special sentinels
2697 /// encode the dynamic case.
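/// For example (an illustrative computation following the formulas below): a
/// subview of memref<64x4xf32, strided<[4, 1]>> taken with offsets [4, 2],
/// sizes [8, 2] and strides [3, 1] infers
/// memref<8x2xf32, strided<[12, 1], offset: 18>>, since the new offset is
/// 0 + 4 * 4 + 2 * 1 = 18 and the new strides are [4 * 3, 1 * 1].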
2698 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2699  ArrayRef<int64_t> staticOffsets,
2700  ArrayRef<int64_t> staticSizes,
2701  ArrayRef<int64_t> staticStrides) {
2702  unsigned rank = sourceMemRefType.getRank();
2703  (void)rank;
2704  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2705  assert(staticSizes.size() == rank && "staticSizes length mismatch");
2706  assert(staticStrides.size() == rank && "staticStrides length mismatch");
2707 
2708  // Extract source offset and strides.
2709  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);
2710 
2711  // Compute target offset whose value is:
2712  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2713  int64_t targetOffset = sourceOffset;
2714  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2715  auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2716  targetOffset = (SaturatedInteger::wrap(targetOffset) +
2717  SaturatedInteger::wrap(staticOffset) *
2718  SaturatedInteger::wrap(sourceStride))
2719  .asInteger();
2720  }
2721 
2722  // Compute target stride whose value is:
2723  // `sourceStrides_i * staticStrides_i`.
2724  SmallVector<int64_t, 4> targetStrides;
2725  targetStrides.reserve(staticOffsets.size());
2726  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2727  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2728  targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2729  SaturatedInteger::wrap(staticStride))
2730  .asInteger());
2731  }
2732 
2733  // The type is now known.
2734  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2735  StridedLayoutAttr::get(sourceMemRefType.getContext(),
2736  targetOffset, targetStrides),
2737  sourceMemRefType.getMemorySpace());
2738 }
2739 
2740 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2741  ArrayRef<OpFoldResult> offsets,
2742  ArrayRef<OpFoldResult> sizes,
2743  ArrayRef<OpFoldResult> strides) {
2744  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2745  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2746  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2747  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2748  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2749  if (!hasValidSizesOffsets(staticOffsets))
2750  return {};
2751  if (!hasValidSizesOffsets(staticSizes))
2752  return {};
2753  if (!hasValidStrides(staticStrides))
2754  return {};
2755  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2756  staticSizes, staticStrides);
2757 }
2758 
2759 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2760  MemRefType sourceRankedTensorType,
2761  ArrayRef<int64_t> offsets,
2762  ArrayRef<int64_t> sizes,
2763  ArrayRef<int64_t> strides) {
2764  auto inferredType = llvm::cast<MemRefType>(
2765  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2766  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2767  "expected result shape rank to not exceed the inferred rank");
2768  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2769  return inferredType;
2770 
2771  // Compute which dimensions are dropped.
2772  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2773  computeRankReductionMask(inferredType.getShape(), resultShape);
2774  assert(dimsToProject.has_value() && "invalid rank reduction");
2775 
2776  // Compute the layout and result type.
2777  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2778  SmallVector<int64_t> rankReducedStrides;
2779  rankReducedStrides.reserve(resultShape.size());
2780  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2781  if (!dimsToProject->contains(idx))
2782  rankReducedStrides.push_back(value);
2783  }
2784  return MemRefType::get(resultShape, inferredType.getElementType(),
2785  StridedLayoutAttr::get(inferredLayout.getContext(),
2786  inferredLayout.getOffset(),
2787  rankReducedStrides),
2788  inferredType.getMemorySpace());
2789 }
2790 
2791 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2792  MemRefType sourceRankedTensorType,
2793  ArrayRef<OpFoldResult> offsets,
2794  ArrayRef<OpFoldResult> sizes,
2795  ArrayRef<OpFoldResult> strides) {
2796  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2797  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2798  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2799  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2800  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2801  return SubViewOp::inferRankReducedResultType(
2802  resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2803  staticStrides);
2804 }
2805 
2806 // Build a SubViewOp with mixed static and dynamic entries and custom result
2807 // type. If the type passed is nullptr, it is inferred.
2808 void SubViewOp::build(OpBuilder &b, OperationState &result,
2809  MemRefType resultType, Value source,
2810  ArrayRef<OpFoldResult> offsets,
2811  ArrayRef<OpFoldResult> sizes,
2812  ArrayRef<OpFoldResult> strides,
2813  ArrayRef<NamedAttribute> attrs) {
2814  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2815  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2816  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2817  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2818  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2819  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2820  // Structuring implementation this way avoids duplication between builders.
2821  if (!resultType) {
2822  resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2823  sourceMemRefType, staticOffsets, staticSizes, staticStrides));
2824  }
2825  result.addAttributes(attrs);
2826  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2827  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2828  b.getDenseI64ArrayAttr(staticSizes),
2829  b.getDenseI64ArrayAttr(staticStrides));
2830 }
2831 
2832 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2833 // type.
2834 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2835  ArrayRef<OpFoldResult> offsets,
2836  ArrayRef<OpFoldResult> sizes,
2837  ArrayRef<OpFoldResult> strides,
2838  ArrayRef<NamedAttribute> attrs) {
2839  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2840 }
2841 
2842 // Build a SubViewOp with static entries and inferred result type.
2843 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2844  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2845  ArrayRef<int64_t> strides,
2846  ArrayRef<NamedAttribute> attrs) {
2847  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2848  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2849  return b.getI64IntegerAttr(v);
2850  }));
2851  SmallVector<OpFoldResult> sizeValues =
2852  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2853  return b.getI64IntegerAttr(v);
2854  }));
2855  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2856  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2857  return b.getI64IntegerAttr(v);
2858  }));
2859  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2860 }
2861 
2862 // Build a SubViewOp with static entries and custom result type. If the
2863 // type passed is nullptr, it is inferred.
2864 void SubViewOp::build(OpBuilder &b, OperationState &result,
2865  MemRefType resultType, Value source,
2866  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2867  ArrayRef<int64_t> strides,
2868  ArrayRef<NamedAttribute> attrs) {
2869  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2870  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2871  return b.getI64IntegerAttr(v);
2872  }));
2873  SmallVector<OpFoldResult> sizeValues =
2874  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2875  return b.getI64IntegerAttr(v);
2876  }));
2877  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2878  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2879  return b.getI64IntegerAttr(v);
2880  }));
2881  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2882  attrs);
2883 }
2884 
2885 // Build a SubViewOp with dynamic entries and custom result type. If the type
2886 // passed is nullptr, it is inferred.
2887 void SubViewOp::build(OpBuilder &b, OperationState &result,
2888  MemRefType resultType, Value source, ValueRange offsets,
2889  ValueRange sizes, ValueRange strides,
2890  ArrayRef<NamedAttribute> attrs) {
2891  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2892  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2893  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2894  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2895  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2896  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2897  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2898 }
2899 
2900 // Build a SubViewOp with dynamic entries and inferred result type.
2901 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2902  ValueRange offsets, ValueRange sizes, ValueRange strides,
2903  ArrayRef<NamedAttribute> attrs) {
2904  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2905 }
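All of the build overloads above funnel into the mixed static/dynamic form; when no result type is supplied, it is inferred from the source type and the (possibly dynamic) offsets, sizes, and strides. A minimal sketch, assuming hypothetical dynamic operands `%i`, `%j`, `%m`, `%n`:

```mlir
%0 = memref.subview %src[%i, %j] [%m, %n] [1, 1]
    : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
```

Dynamic offsets make the inferred offset dynamic, while the unit strides preserve the source strides.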
2906 
2907 /// For ViewLikeOpInterface.
2908 Value SubViewOp::getViewSource() { return getSource(); }
2909 
2910 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2911 /// static value).
2912 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2913  int64_t t1Offset, t2Offset;
2914  SmallVector<int64_t> t1Strides, t2Strides;
2915  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2916  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2917  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2918 }
2919 
2920 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2921 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2922 /// marked as dropped in `droppedDims`.
2923 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2924  const llvm::SmallBitVector &droppedDims) {
2925  assert(size_t(t1.getRank()) == droppedDims.size() &&
2926  "incorrect number of bits");
2927  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2928  "incorrect number of dropped dims");
2929  int64_t t1Offset, t2Offset;
2930  SmallVector<int64_t> t1Strides, t2Strides;
2931  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2932  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2933  if (failed(res1) || failed(res2))
2934  return false;
2935  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2936  if (droppedDims[i])
2937  continue;
2938  if (t1Strides[i] != t2Strides[j])
2939  return false;
2940  ++j;
2941  }
2942  return true;
2943 }
2944 
2945 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2946  Operation *op, Type expectedType) {
2947  auto memrefType = llvm::cast<ShapedType>(expectedType);
2948  switch (result) {
2949  case SliceVerificationResult::Success:
2950  return success();
2951  case SliceVerificationResult::RankTooLarge:
2952  return op->emitError("expected result rank to be smaller or equal to ")
2953  << "the source rank. ";
2954  case SliceVerificationResult::SizeMismatch:
2955  return op->emitError("expected result type to be ")
2956  << expectedType
2957  << " or a rank-reduced version. (mismatch of result sizes) ";
2958  case SliceVerificationResult::ElemTypeMismatch:
2959  return op->emitError("expected result element type to be ")
2960  << memrefType.getElementType();
2961  case SliceVerificationResult::MemSpaceMismatch:
2962  return op->emitError("expected result and source memory spaces to match.");
2963  case SliceVerificationResult::LayoutMismatch:
2964  return op->emitError("expected result type to be ")
2965  << expectedType
2966  << " or a rank-reduced version. (mismatch of result layout) ";
2967  }
2968  llvm_unreachable("unexpected subview verification result");
2969 }
2970 
2971 /// Verifier for SubViewOp.
2972 LogicalResult SubViewOp::verify() {
2973  MemRefType baseType = getSourceType();
2974  MemRefType subViewType = getType();
2975 
2976  // The base memref and the view memref should be in the same memory space.
2977  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2978  return emitError("different memory spaces specified for base memref "
2979  "type ")
2980  << baseType << " and subview memref type " << subViewType;
2981 
2982  // Verify that the base memref type has a strided layout map.
2983  if (!isStrided(baseType))
2984  return emitError("base type ") << baseType << " is not strided";
2985 
2986  // Compute the expected result type, assuming that there are no rank
2987  // reductions.
2988  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
2989  baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
2990 
2991  // Verify all properties of a shaped type: rank, element type and dimension
2992  // sizes. This takes into account potential rank reductions.
2993  auto shapedTypeVerification = isRankReducedType(
2994  /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2995  if (shapedTypeVerification != SliceVerificationResult::Success)
2996  return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2997 
2998  // Make sure that the memory space did not change.
2999  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3000  return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
3001  *this, expectedType);
3002 
3003  // Verify the offset of the layout map.
3004  if (!haveCompatibleOffsets(expectedType, subViewType))
3005  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3006  *this, expectedType);
3007 
3008  // All that remains to verify now are the strides. First, compute
3009  // the unused dimensions due to rank reductions. We have to look at sizes and
3010  // strides to decide which dimensions were dropped. This function also
3011  // partially verifies strides in case of rank reductions.
3012  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3013  getMixedSizes());
3014  if (failed(unusedDims))
3015  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3016  *this, expectedType);
3017 
3018  // Strides must match.
3019  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3020  return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3021  *this, expectedType);
3022 
3023  return success();
3024 }
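A minimal sketch of what this verifier accepts and rejects, assuming a hypothetical `%arg0 : memref<8x16x4xf32>`; the inferred (non-rank-reduced) type for these offsets/sizes/strides is `memref<1x8x4xf32, strided<[64, 4, 1], offset: 64>>`.

```mlir
// Accepted: a rank-reduced version of the inferred type.
%ok = memref.subview %arg0[1, 0, 0] [1, 8, 4] [1, 1, 1]
    : memref<8x16x4xf32> to memref<8x4xf32, strided<[4, 1], offset: 64>>

// Rejected (mismatch of result sizes): the declared sizes differ from the
// inferred ones.
// %bad = memref.subview %arg0[1, 0, 0] [1, 8, 4] [1, 1, 1]
//     : memref<8x16x4xf32> to memref<8x8xf32, strided<[4, 1], offset: 64>>
```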
3025 
3026 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3027  return os << "range " << range.offset << ":" << range.size << ":"
3028  << range.stride;
3029 }
3030 
3031 /// Return the list of Range (i.e. offset, size, stride). Each Range
3032 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3033 /// with `b` at location `loc`.
3034 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3035  OpBuilder &b, Location loc) {
3036  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3037  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3038  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3039  SmallVector<Range, 8> res;
3040  unsigned rank = ranks[0];
3041  res.reserve(rank);
3042  for (unsigned idx = 0; idx < rank; ++idx) {
3043  Value offset =
3044  op.isDynamicOffset(idx)
3045  ? op.getDynamicOffset(idx)
3046  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
3047  Value size =
3048  op.isDynamicSize(idx)
3049  ? op.getDynamicSize(idx)
3050  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
3051  Value stride =
3052  op.isDynamicStride(idx)
3053  ? op.getDynamicStride(idx)
3054  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
3055  res.emplace_back(Range{offset, size, stride});
3056  }
3057  return res;
3058 }
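A minimal sketch of the ranges this helper would return, assuming a hypothetical subview with mixed static and dynamic entries; the static entries are materialized as fresh `arith.constant` index values at `loc`.

```mlir
%0 = memref.subview %src[%i, 0] [4, %n] [1, 2]
    : memref<16x32xf32> to memref<4x?xf32, strided<[32, 2], offset: ?>>
// getOrCreateRanges would return:
//   {offset = %i, size = <constant 4>, stride = <constant 1>}
//   {offset = <constant 0>, size = %n, stride = <constant 2>}
```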
3059 
3060 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3061 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3062 /// the rank of the inferred result type if `currentResultType` is lower rank
3063 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3064 /// together with the result type. In this case, it is important to compute
3065 /// the dropped dimensions using `currentSourceType` whose strides align with
3066 /// `currentResultType`.
3067 static MemRefType getCanonicalSubViewResultType(
3068  MemRefType currentResultType, MemRefType currentSourceType,
3069  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3070  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3071  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
3072  sourceType, mixedOffsets, mixedSizes, mixedStrides));
3073  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3074  currentSourceType, currentResultType, mixedSizes);
3075  if (failed(unusedDims))
3076  return nullptr;
3077 
3078  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3079  SmallVector<int64_t> shape, strides;
3080  unsigned numDimsAfterReduction =
3081  nonRankReducedType.getRank() - unusedDims->count();
3082  shape.reserve(numDimsAfterReduction);
3083  strides.reserve(numDimsAfterReduction);
3084  for (const auto &[idx, size, stride] :
3085  llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3086  nonRankReducedType.getShape(), layout.getStrides())) {
3087  if (unusedDims->test(idx))
3088  continue;
3089  shape.push_back(size);
3090  strides.push_back(stride);
3091  }
3092 
3093  return MemRefType::get(shape, nonRankReducedType.getElementType(),
3094  StridedLayoutAttr::get(sourceType.getContext(),
3095  layout.getOffset(), strides),
3096  nonRankReducedType.getMemorySpace());
3097 }
3098 
3099 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3100  OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3101  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3102  unsigned rank = memrefType.getRank();
3103  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3104  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3105  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3106  auto targetType =
3107  llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
3108  targetShape, memrefType, offsets, sizes, strides));
3109  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3110  sizes, strides);
3111 }
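A minimal sketch of the op this helper builds, assuming a hypothetical `%m : memref<1x16x4xf32>` and `targetShape = {16, 4}`: zero offsets, unit strides, the full source sizes, and a rank-reduced result type.

```mlir
%0 = memref.subview %m[0, 0, 0] [1, 16, 4] [1, 1, 1]
    : memref<1x16x4xf32> to memref<16x4xf32, strided<[4, 1]>>
```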
3112 
3113 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3114  Value value,
3115  ArrayRef<int64_t> desiredShape) {
3116  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3117  assert(sourceMemrefType && "not a ranked memref type");
3118  auto sourceShape = sourceMemrefType.getShape();
3119  if (sourceShape.equals(desiredShape))
3120  return value;
3121  auto maybeRankReductionMask =
3122  mlir::computeRankReductionMask(sourceShape, desiredShape);
3123  if (!maybeRankReductionMask)
3124  return failure();
3125  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3126 }
3127 
3128 /// Helper method to check if a `subview` operation is trivially a no-op. This
3129 /// is the case if all offsets are zero, all strides are 1, and the source
3130 /// shape is the same as the size of the subview. In such cases, the subview
3131 /// can be folded into its source.
3132 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3133  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3134  return false;
3135 
3136  auto mixedOffsets = subViewOp.getMixedOffsets();
3137  auto mixedSizes = subViewOp.getMixedSizes();
3138  auto mixedStrides = subViewOp.getMixedStrides();
3139 
3140  // Check offsets are zero.
3141  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3142  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3143  return !intValue || intValue.value() != 0;
3144  }))
3145  return false;
3146 
3147  // Check strides are one.
3148  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3149  std::optional<int64_t> intValue = getConstantIntValue(ofr);
3150  return !intValue || intValue.value() != 1;
3151  }))
3152  return false;
3153 
3154  // Check that all size values are static and match the (static) source shape.
3155  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3156  for (const auto &size : llvm::enumerate(mixedSizes)) {
3157  std::optional<int64_t> intValue = getConstantIntValue(size.value());
3158  if (!intValue || *intValue != sourceShape[size.index()])
3159  return false;
3160  }
3161  // All conditions met. The `SubViewOp` is foldable as a no-op.
3162  return true;
3163 }
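A minimal sketch of a trivially foldable subview, assuming a hypothetical `%src : memref<16x32xf32>`: all offsets are zero, all strides are one, and the sizes equal the static source shape.

```mlir
%0 = memref.subview %src[0, 0] [16, 32] [1, 1]
    : memref<16x32xf32> to memref<16x32xf32>
```

Because the source and result types are identical here, the TrivialSubViewOpFolder pattern below replaces the op with `%src` directly; if only the layouts differed, a `memref.cast` would be inserted instead.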
3164 
3165 namespace {
3166 /// Pattern to rewrite a subview op with MemRefCast arguments.
3167 /// This essentially pushes memref.cast past its consuming subview when
3168 /// `canFoldIntoConsumerOp` is true.
3169 ///
3170 /// Example:
3171 /// ```
3172 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3173 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3174 /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3175 /// ```
3176 /// is rewritten into:
3177 /// ```
3178 /// %0 = memref.subview %V[0, 0][3, 4][1, 1] : memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3179 /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3180 /// memref<3x4xf32, strided<[?, 1], offset: ?>>
3181 /// ```
3182 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3183 public:
3184  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3185 
3186  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3187  PatternRewriter &rewriter) const override {
3188  // Any constant operand, just return to let SubViewOpConstantFolder kick
3189  // in.
3190  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3191  return matchPattern(operand, matchConstantIndex());
3192  }))
3193  return failure();
3194 
3195  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3196  if (!castOp)
3197  return failure();
3198 
3199  if (!CastOp::canFoldIntoConsumerOp(castOp))
3200  return failure();
3201 
3202  // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3203  // the MemRefCastOp source operand type to infer the result type and the
3204  // current SubViewOp source operand type to compute the dropped dimensions
3205  // if the operation is rank-reducing.
3206  auto resultType = getCanonicalSubViewResultType(
3207  subViewOp.getType(), subViewOp.getSourceType(),
3208  llvm::cast<MemRefType>(castOp.getSource().getType()),
3209  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3210  subViewOp.getMixedStrides());
3211  if (!resultType)
3212  return failure();
3213 
3214  Value newSubView = rewriter.create<SubViewOp>(
3215  subViewOp.getLoc(), resultType, castOp.getSource(),
3216  subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3217  subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3218  subViewOp.getStaticStrides());
3219  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3220  newSubView);
3221  return success();
3222  }
3223 };
3224 
3225 /// Canonicalize subview ops that are no-ops. If the source and result types
3226 /// differ only in layout, the subview is replaced with a cast instead.
3227 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3228 public:
3229  using OpRewritePattern<SubViewOp>::OpRewritePattern;
3230 
3231  LogicalResult matchAndRewrite(SubViewOp subViewOp,
3232  PatternRewriter &rewriter) const override {
3233  if (!isTrivialSubViewOp(subViewOp))
3234  return failure();
3235  if (subViewOp.getSourceType() == subViewOp.getType()) {
3236  rewriter.replaceOp(subViewOp, subViewOp.getSource());
3237  return success();
3238  }
3239  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3240  subViewOp.getSource());
3241  return success();
3242  }
3243 };
3244 } // namespace
3245 
3246 /// Return the canonical type of the result of a subview.
3247 struct SubViewReturnTypeCanonicalizer {
3248  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3249  ArrayRef<OpFoldResult> mixedSizes,
3250  ArrayRef<OpFoldResult> mixedStrides) {
3251  // Infer a memref type without taking into account any rank reductions.
3252  auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
3253  mixedSizes, mixedStrides);
3254  if (!resTy)
3255  return {};
3256  MemRefType nonReducedType = cast<MemRefType>(resTy);
3257 
3258  // Directly return the non-rank reduced type if there are no dropped dims.
3259  llvm::SmallBitVector droppedDims = op.getDroppedDims();
3260  if (droppedDims.none())
3261  return nonReducedType;
3262 
3263  // Take the strides and offset from the non-rank reduced type.
3264  auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
3265 
3266  // Drop dims from shape and strides.
3267  SmallVector<int64_t> targetShape;
3268  SmallVector<int64_t> targetStrides;
3269  for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3270  if (droppedDims.test(i))
3271  continue;
3272  targetStrides.push_back(nonReducedStrides[i]);
3273  targetShape.push_back(nonReducedType.getDimSize(i));
3274  }
3275 
3276  return MemRefType::get(targetShape, nonReducedType.getElementType(),
3277  StridedLayoutAttr::get(nonReducedType.getContext(),
3278  offset, targetStrides),
3279  nonReducedType.getMemorySpace());
3280  }
3281 };
3282 
3283 /// A canonicalizer wrapper to replace SubViewOps.
3284 struct SubViewCanonicalizer {
3285  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3286  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3287  }
3288 };
3289 
3290 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3291  MLIRContext *context) {
3292  results
3293  .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3294  SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3295  SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3296 }
3297 
3298 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3299  MemRefType sourceMemrefType = getSource().getType();
3300  MemRefType resultMemrefType = getResult().getType();
3301  auto resultLayout =
3302  dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3303 
3304  if (resultMemrefType == sourceMemrefType &&
3305  resultMemrefType.hasStaticShape() &&
3306  (!resultLayout || resultLayout.hasStaticLayout())) {
3307  return getViewSource();
3308  }
3309 
3310  // Fold subview(subview(x)), where both subviews have the same size and the
3311  // second subview's offsets are all zero. (I.e., the second subview is a
3312  // no-op.)
3313  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3314  auto srcSizes = srcSubview.getMixedSizes();
3315  auto sizes = getMixedSizes();
3316  auto offsets = getMixedOffsets();
3317  bool allOffsetsZero = llvm::all_of(
3318  offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3319  auto strides = getMixedStrides();
3320  bool allStridesOne = llvm::all_of(
3321  strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3322  bool allSizesSame = llvm::equal(sizes, srcSizes);
3323  if (allOffsetsZero && allStridesOne && allSizesSame &&
3324  resultMemrefType == sourceMemrefType)
3325  return getViewSource();
3326  }
3327 
3328  return {};
3329 }
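A minimal sketch of the subview-of-subview fold, assuming a hypothetical `%arg : memref<16x16xf32>`: the second subview has zero offsets, unit strides, the same sizes as its source subview, and an identical type, so it folds to `%0`.

```mlir
%0 = memref.subview %arg[4, 8] [4, 4] [1, 1]
    : memref<16x16xf32> to memref<4x4xf32, strided<[16, 1], offset: 72>>
%1 = memref.subview %0[0, 0] [4, 4] [1, 1]
    : memref<4x4xf32, strided<[16, 1], offset: 72>>
      to memref<4x4xf32, strided<[16, 1], offset: 72>>
```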
3330 
3331 //===----------------------------------------------------------------------===//
3332 // TransposeOp
3333 //===----------------------------------------------------------------------===//
3334 
3335 void TransposeOp::getAsmResultNames(
3336  function_ref<void(Value, StringRef)> setNameFn) {
3337  setNameFn(getResult(), "transpose");
3338 }
3339 
3340 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3341 static MemRefType inferTransposeResultType(MemRefType memRefType,
3342  AffineMap permutationMap) {
3343  auto originalSizes = memRefType.getShape();
3344  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
3345  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3346 
3347  // Compute permuted sizes and strides.
3348  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3349  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3350 
3351  return MemRefType::Builder(memRefType)
3352  .setShape(sizes)
3353  .setLayout(
3354  StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3355 }
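A minimal sketch of the inferred type, assuming a hypothetical contiguous input `%arg : memref<4x8xf32>` (strides `[8, 1]`): the permutation is applied to both the shape and the strides, so the result is strided rather than contiguous.

```mlir
%0 = memref.transpose %arg (d0, d1) -> (d1, d0)
    : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
```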
3356 
3357 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3358  AffineMapAttr permutation,
3359  ArrayRef<NamedAttribute> attrs) {
3360  auto permutationMap = permutation.getValue();
3361  assert(permutationMap);
3362 
3363  auto memRefType = llvm::cast<MemRefType>(in.getType());
3364  // Compute result type.
3365  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3366 
3367  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3368  build(b, result, resultType, in, attrs);
3369 }
3370 
3371 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3372 void TransposeOp::print(OpAsmPrinter &p) {
3373  p << " " << getIn() << " " << getPermutation();
3374  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3375  p << " : " << getIn().getType() << " to " << getType();
3376 }
3377 
3378 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3379  OpAsmParser::UnresolvedOperand in;
3380  AffineMap permutation;
3381  MemRefType srcType, dstType;
3382  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3383  parser.parseOptionalAttrDict(result.attributes) ||
3384  parser.parseColonType(srcType) ||
3385  parser.resolveOperand(in, srcType, result.operands) ||
3386  parser.parseKeywordType("to", dstType) ||
3387  parser.addTypeToList(dstType, result.types))
3388  return failure();
3389 
3390  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3391  AffineMapAttr::get(permutation));
3392  return success();
3393 }
3394 
3395 LogicalResult TransposeOp::verify() {
3396  if (!getPermutation().isPermutation())
3397  return emitOpError("expected a permutation map");
3398  if (getPermutation().getNumDims() != getIn().getType().getRank())
3399  return emitOpError("expected a permutation map of same rank as the input");
3400 
3401  auto srcType = llvm::cast<MemRefType>(getIn().getType());
3402  auto resultType = llvm::cast<MemRefType>(getType());
3403  auto canonicalResultType = canonicalizeStridedLayout(
3404  inferTransposeResultType(srcType, getPermutation()));
3405 
3406  if (canonicalizeStridedLayout(resultType) != canonicalResultType)
3407  return emitOpError("result type ")
3408  << resultType
3409  << " is not equivalent to the canonical transposed input type "
3410  << canonicalResultType;
3411  return success();
3412 }
3413 
3414 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3415  // First check for an identity permutation; we can fold it away if the input
3416  // and result types are already identical.
3417  if (getPermutation().isIdentity() && getType() == getIn().getType())
3418  return getIn();
3419  // Fold two consecutive memref.transpose Ops into one by composing their
3420  // permutation maps.
3421  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3422  AffineMap composedPermutation =
3423  getPermutation().compose(otherTransposeOp.getPermutation());
3424  getInMutable().assign(otherTransposeOp.getIn());
3425  setPermutation(composedPermutation);
3426  return getResult();
3427  }
3428  return {};
3429 }
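A minimal sketch of the consecutive-transpose fold, assuming a hypothetical `%arg : memref<2x3x4xf32>`: the two permutations compose into `(d0, d1, d2) -> (d1, d2, d0)`, and `%1` is rewritten in place to transpose `%arg` directly with that composed map.

```mlir
%0 = memref.transpose %arg (d0, d1, d2) -> (d1, d0, d2)
    : memref<2x3x4xf32> to memref<3x2x4xf32, strided<[4, 12, 1]>>
%1 = memref.transpose %0 (d0, d1, d2) -> (d0, d2, d1)
    : memref<3x2x4xf32, strided<[4, 12, 1]>>
      to memref<3x4x2xf32, strided<[4, 1, 12]>>
```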
3430 
3431 //===----------------------------------------------------------------------===//
3432 // ViewOp
3433 //===----------------------------------------------------------------------===//
3434 
3435 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3436  setNameFn(getResult(), "view");
3437 }
3438 
3439 LogicalResult ViewOp::verify() {
3440  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3441  auto viewType = getType();
3442 
3443  // The base memref should have identity layout map (or none).
3444  if (!baseType.getLayout().isIdentity())
3445  return emitError("unsupported map for base memref type ") << baseType;
3446 
3447  // The result memref should have identity layout map (or none).
3448  if (!viewType.getLayout().isIdentity())
3449  return emitError("unsupported map for result memref type ") << viewType;
3450 
3451  // The base memref and the view memref should be in the same memory space.
3452  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3453  return emitError("different memory spaces specified for base memref "
3454  "type ")
3455  << baseType << " and view memref type " << viewType;
3456 
3457  // Verify that we have the correct number of sizes for the result type.
3458  unsigned numDynamicDims = viewType.getNumDynamicDims();
3459  if (getSizes().size() != numDynamicDims)
3460  return emitError("incorrect number of size operands for type ") << viewType;
3461 
3462  return success();
3463 }
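A minimal sketch of a well-formed view, assuming a hypothetical byte buffer `%buf : memref<2048xi8>`, a byte shift `%offset`, and one size operand `%n` for the single dynamic result dimension; both the base and the result use the identity layout.

```mlir
%0 = memref.view %buf[%offset][%n] : memref<2048xi8> to memref<?x16xf32>
```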
3464 
3465 Value ViewOp::getViewSource() { return getSource(); }
3466 
3467 namespace {
3468 
3469 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3470  using OpRewritePattern<ViewOp>::OpRewritePattern;
3471 
3472  LogicalResult matchAndRewrite(ViewOp viewOp,
3473  PatternRewriter &rewriter) const override {
3474  // Return if none of the operands are constants.
3475  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3476  return matchPattern(operand, matchConstantIndex());
3477  }))
3478  return failure();
3479 
3480  // Get result memref type.
3481  auto memrefType = viewOp.getType();
3482 
3483  // Get the offset from the old memref view type 'memrefType'.
3484  int64_t oldOffset;
3485  SmallVector<int64_t, 4> oldStrides;
3486  if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
3487  return failure();
3488  assert(oldOffset == 0 && "Expected 0 offset");
3489 
3490  SmallVector<Value, 4> newOperands;
3491 
3492  // Offset cannot be folded into result type.
3493 
3494  // Fold any dynamic dim operands which are produced by a constant.
3495  SmallVector<int64_t, 4> newShapeConstants;
3496  newShapeConstants.reserve(memrefType.getRank());
3497 
3498  unsigned dynamicDimPos = 0;
3499  unsigned rank = memrefType.getRank();
3500  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3501  int64_t dimSize = memrefType.getDimSize(dim);
3502  // If this is already a static dimension, keep it.
3503  if (!ShapedType::isDynamic(dimSize)) {
3504  newShapeConstants.push_back(dimSize);
3505  continue;
3506  }
3507  auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3508  if (auto constantIndexOp =
3509  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3510  // Dynamic shape dimension will be folded.
3511  newShapeConstants.push_back(constantIndexOp.value());
3512  } else {
3513  // Dynamic shape dimension not folded; copy operand from old memref.
3514  newShapeConstants.push_back(dimSize);
3515  newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3516  }
3517  dynamicDimPos++;
3518  }
3519 
3520  // Create new memref type with constant folded dims.
3521  MemRefType newMemRefType =
3522  MemRefType::Builder(memrefType).setShape(newShapeConstants);
3523  // Nothing new, don't fold.
3524  if (newMemRefType == memrefType)
3525  return failure();
3526 
3527  // Create new ViewOp.
3528  auto newViewOp = rewriter.create<ViewOp>(
3529  viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3530  viewOp.getByteShift(), newOperands);
3531  // Insert a cast so we have the same type as the old memref type.
3532  rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3533  return success();
3534  }
3535 };
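A minimal sketch of the rewrite performed by `ViewOpShapeFolder`, assuming a hypothetical `%buf : memref<2048xi8>` whose dynamic size operand is a constant: the constant is folded into the result shape and a cast restores the original type.

```mlir
%c16 = arith.constant 16 : index
%0 = memref.view %buf[%offset][%c16] : memref<2048xi8> to memref<?x4xf32>
// ... becomes:
%1 = memref.view %buf[%offset][] : memref<2048xi8> to memref<16x4xf32>
%2 = memref.cast %1 : memref<16x4xf32> to memref<?x4xf32>
```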
3536 
3537 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3538  using OpRewritePattern<ViewOp>::OpRewritePattern;
3539 
3540  LogicalResult matchAndRewrite(ViewOp viewOp,
3541  PatternRewriter &rewriter) const override {
3542  Value memrefOperand = viewOp.getOperand(0);
3543  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3544  if (!memrefCastOp)
3545  return failure();
3546  Value allocOperand = memrefCastOp.getOperand();
3547  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3548  if (!allocOp)
3549  return failure();
3550  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3551  viewOp.getByteShift(),
3552  viewOp.getSizes());
3553  return success();
3554  }
3555 };
3556 
3557 } // namespace
3558 
3559 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3560  MLIRContext *context) {
3561  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3562 }
3563 
3564 //===----------------------------------------------------------------------===//
3565 // AtomicRMWOp
3566 //===----------------------------------------------------------------------===//
3567 
3568 LogicalResult AtomicRMWOp::verify() {
3569  if (getMemRefType().getRank() != getNumOperands() - 2)
3570  return emitOpError(
3571  "expects the number of subscripts to be equal to memref rank");
3572  switch (getKind()) {
3573  case arith::AtomicRMWKind::addf:
3574  case arith::AtomicRMWKind::maximumf:
3575  case arith::AtomicRMWKind::minimumf:
3576  case arith::AtomicRMWKind::mulf:
3577  if (!llvm::isa<FloatType>(getValue().getType()))
3578  return emitOpError() << "with kind '"
3579  << arith::stringifyAtomicRMWKind(getKind())
3580  << "' expects a floating-point type";
3581  break;
3582  case arith::AtomicRMWKind::addi:
3583  case arith::AtomicRMWKind::maxs:
3584  case arith::AtomicRMWKind::maxu:
3585  case arith::AtomicRMWKind::mins:
3586  case arith::AtomicRMWKind::minu:
3587  case arith::AtomicRMWKind::muli:
3588  case arith::AtomicRMWKind::ori:
3589  case arith::AtomicRMWKind::andi:
3590  if (!llvm::isa<IntegerType>(getValue().getType()))
3591  return emitOpError() << "with kind '"
3592  << arith::stringifyAtomicRMWKind(getKind())
3593  << "' expects an integer type";
3594  break;
3595  default:
3596  break;
3597  }
3598  return success();
3599 }
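A minimal sketch of an op this verifier accepts, assuming hypothetical operands: the `addf` kind requires a floating-point value, and one index operand is needed per memref dimension.

```mlir
%old = memref.atomic_rmw addf %value, %mem[%i, %j]
    : (f32, memref<10x10xf32>) -> f32
```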
3600 
3601 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3602  /// atomicrmw(memrefcast) -> atomicrmw
3603  if (succeeded(foldMemRefCast(*this, getValue())))
3604  return getResult();
3605  return OpFoldResult();
3606 }
3607 
3608 //===----------------------------------------------------------------------===//
3609 // TableGen'd op method definitions
3610 //===----------------------------------------------------------------------===//
3611 
3612 #define GET_OP_CLASSES
3613 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"