MemRefOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
23 #include "llvm/ADT/STLExtras.h"
24 
25 using namespace mlir;
26 using namespace mlir::memref;
27 
28 /// Materialize a single constant operation from a given attribute value with
29 /// the desired resultant type.
30 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
31                                               Attribute value, Type type,
32                                               Location loc) {
33  if (arith::ConstantOp::isBuildableWith(value, type))
34  return builder.create<arith::ConstantOp>(loc, value, type);
35  if (ConstantOp::isBuildableWith(value, type))
36  return builder.create<ConstantOp>(loc, value, type);
37  return nullptr;
38 }
39 
40 //===----------------------------------------------------------------------===//
41 // Common canonicalization pattern support logic
42 //===----------------------------------------------------------------------===//
43 
44 /// This is a common utility used for patterns of the form
45 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
46 /// into the root operation directly.
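/// As an illustrative sketch (invented IR, not taken from the tests), such a
/// fold rewrites
///
/// ```mlir
/// %0 = memref.cast %arg0 : memref<8xf32> to memref<?xf32>
/// memref.dealloc %0 : memref<?xf32>
/// ```
///
/// into
///
/// ```mlir
/// memref.dealloc %arg0 : memref<8xf32>
/// ```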
47 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
48   bool folded = false;
49  for (OpOperand &operand : op->getOpOperands()) {
50  auto cast = operand.get().getDefiningOp<CastOp>();
51  if (cast && operand.get() != inner &&
52  !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
53  operand.set(cast.getOperand());
54  folded = true;
55  }
56  }
57  return success(folded);
58 }
59 
60 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
61 /// type.
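/// For example (illustrative): memref<4x?xf32> maps to tensor<4x?xf32>, and
/// memref<*xf32> maps to tensor<*xf32>.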
62 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
63   if (auto memref = type.dyn_cast<MemRefType>())
64  return RankedTensorType::get(memref.getShape(), memref.getElementType());
65  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
66  return UnrankedTensorType::get(memref.getElementType());
67  return NoneType::get(type.getContext());
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // AllocOp / AllocaOp
72 //===----------------------------------------------------------------------===//
73 
74 template <typename AllocLikeOp>
75 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
77  "applies to only alloc or alloca");
78  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
79  if (!memRefType)
80  return op.emitOpError("result must be a memref");
81 
82  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
83  memRefType.getNumDynamicDims())
84  return op.emitOpError("dimension operand count does not equal memref "
85  "dynamic dimension count");
86 
87  unsigned numSymbols = 0;
88  if (!memRefType.getLayout().isIdentity())
89  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
90  if (op.symbolOperands().size() != numSymbols)
91  return op.emitOpError("symbol operand count does not equal memref symbol "
92  "count: expected ")
93  << numSymbols << ", got " << op.symbolOperands().size();
94 
95  return success();
96 }
97 
98 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
99 
100 static LogicalResult verify(AllocaOp op) {
101  // An alloca op needs to have an ancestor with an allocation scope trait.
102  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
103  return op.emitOpError(
104  "requires an ancestor op with AutomaticAllocationScope trait");
105 
106  return verifyAllocLikeOp(op);
107 }
108 
109 namespace {
110 /// Fold constant dimensions into an alloc-like operation.
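/// A sketch of the rewrite, using invented SSA names:
///
/// ```mlir
/// %c16 = arith.constant 16 : index
/// %0 = memref.alloc(%c16, %n) : memref<?x?xf32>
/// ```
///
/// becomes
///
/// ```mlir
/// %1 = memref.alloc(%n) : memref<16x?xf32>
/// %0 = memref.cast %1 : memref<16x?xf32> to memref<?x?xf32>
/// ```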
111 template <typename AllocLikeOp>
112 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
113   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
114 
115  LogicalResult matchAndRewrite(AllocLikeOp alloc,
116  PatternRewriter &rewriter) const override {
117  // Check to see if any dimensions operands are constants. If so, we can
118  // substitute and drop them.
119  if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
120  return matchPattern(operand, matchConstantIndex());
121  }))
122  return failure();
123 
124  auto memrefType = alloc.getType();
125 
126  // Ok, we have one or more constant operands. Collect the non-constant ones
127  // and keep track of the resultant memref type to build.
128  SmallVector<int64_t, 4> newShapeConstants;
129  newShapeConstants.reserve(memrefType.getRank());
130  SmallVector<Value, 4> dynamicSizes;
131 
132  unsigned dynamicDimPos = 0;
133  for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
134  int64_t dimSize = memrefType.getDimSize(dim);
135       // If this is already a static dimension, keep it.
136  if (dimSize != -1) {
137  newShapeConstants.push_back(dimSize);
138  continue;
139  }
140  auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
141  auto *defOp = dynamicSize.getDefiningOp();
142  if (auto constantIndexOp =
143  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
144  // Dynamic shape dimension will be folded.
145  newShapeConstants.push_back(constantIndexOp.value());
146  } else {
147  // Dynamic shape dimension not folded; copy dynamicSize from old memref.
148  newShapeConstants.push_back(-1);
149  dynamicSizes.push_back(dynamicSize);
150  }
151  dynamicDimPos++;
152  }
153 
154  // Create new memref type (which will have fewer dynamic dimensions).
155  MemRefType newMemRefType =
156  MemRefType::Builder(memrefType).setShape(newShapeConstants);
157  assert(static_cast<int64_t>(dynamicSizes.size()) ==
158  newMemRefType.getNumDynamicDims());
159 
160  // Create and insert the alloc op for the new memref.
161  auto newAlloc = rewriter.create<AllocLikeOp>(
162  alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
163  alloc.alignmentAttr());
164  // Insert a cast so we have the same type as the old alloc.
165  auto resultCast =
166  rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
167 
168  rewriter.replaceOp(alloc, {resultCast});
169  return success();
170  }
171 };
172 
173 /// Fold away alloc operations that have no uses, or whose only uses are stores
173 /// into the allocated buffer and deallocs.
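/// For instance (illustrative IR), the following allocation and all of its
/// uses are erased, since the stored value is never read:
///
/// ```mlir
/// %0 = memref.alloc() : memref<4xf32>
/// memref.store %cst, %0[%c0] : memref<4xf32>
/// memref.dealloc %0 : memref<4xf32>
/// ```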
174 template <typename T>
175 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
176   using OpRewritePattern<T>::OpRewritePattern;
177 
178  LogicalResult matchAndRewrite(T alloc,
179  PatternRewriter &rewriter) const override {
180  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
181  if (auto storeOp = dyn_cast<StoreOp>(op))
182  return storeOp.value() == alloc;
183  return !isa<DeallocOp>(op);
184  }))
185  return failure();
186 
187  for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
188  rewriter.eraseOp(user);
189 
190  rewriter.eraseOp(alloc);
191  return success();
192  }
193 };
194 } // namespace
195 
196 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
197  MLIRContext *context) {
198  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
199 }
200 
201 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
202  MLIRContext *context) {
203  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
204  context);
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // AllocaScopeOp
209 //===----------------------------------------------------------------------===//
210 
211 static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
212  bool printBlockTerminators = false;
213 
214  p << ' ';
215  if (!op.results().empty()) {
216  p << " -> (" << op.getResultTypes() << ")";
217  printBlockTerminators = true;
218  }
219  p << ' ';
220  p.printRegion(op.bodyRegion(),
221  /*printEntryBlockArgs=*/false,
222  /*printBlockTerminators=*/printBlockTerminators);
223  p.printOptionalAttrDict(op->getAttrs());
224 }
225 
226 static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
227                                       OperationState &result) {
228  // Create a region for the body.
229  result.regions.reserve(1);
230  Region *bodyRegion = result.addRegion();
231 
232  // Parse optional results type list.
233  if (parser.parseOptionalArrowTypeList(result.types))
234  return failure();
235 
236  // Parse the body region.
237  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
238  return failure();
239  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
240  result.location);
241 
242  // Parse the optional attribute list.
243  if (parser.parseOptionalAttrDict(result.attributes))
244  return failure();
245 
246  return success();
247 }
248 
249 static LogicalResult verify(AllocaScopeOp op) {
250  if (failed(RegionBranchOpInterface::verifyTypes(op)))
251  return failure();
252 
253  return success();
254 }
255 
256 void AllocaScopeOp::getSuccessorRegions(
257  Optional<unsigned> index, ArrayRef<Attribute> operands,
258     SmallVectorImpl<RegionSuccessor> &regions) {
259   if (index.hasValue()) {
260  regions.push_back(RegionSuccessor(getResults()));
261  return;
262  }
263 
264  regions.push_back(RegionSuccessor(&bodyRegion()));
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // AssumeAlignmentOp
269 //===----------------------------------------------------------------------===//
270 
271 static LogicalResult verify(AssumeAlignmentOp op) {
272  unsigned alignment = op.alignment();
273  if (!llvm::isPowerOf2_32(alignment))
274  return op.emitOpError("alignment must be power of 2");
275  return success();
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // CastOp
280 //===----------------------------------------------------------------------===//
281 
282 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
283 /// source memref. This is useful to fold a memref.cast into a consuming op
284 /// and implement canonicalization patterns for ops in different dialects that
285 /// may consume the results of memref.cast operations. Such foldable memref.cast
286 /// operations are typically inserted as `view` and `subview` ops are
287 /// canonicalized, to preserve the type compatibility of their uses.
288 ///
289 /// Returns true when all conditions are met:
290 /// 1. source and result are ranked memrefs with strided semantics and same
291 /// element type and rank.
292 /// 2. each of the source's size, offset or stride has more static information
293 /// than the corresponding result's size, offset or stride.
294 ///
295 /// Example 1:
296 /// ```mlir
297 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
298 /// %2 = consumer %1 ... : memref<?x?xf32> ...
299 /// ```
300 ///
301 /// may fold into:
302 ///
303 /// ```mlir
304 /// %2 = consumer %0 ... : memref<8x16xf32> ...
305 /// ```
306 ///
307 /// Example 2:
308 /// ```
309 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
310 /// to memref<?x?xf32>
311 /// consumer %1 : memref<?x?xf32> ...
312 /// ```
313 ///
314 /// may fold into:
315 ///
316 /// ```
317 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
318 /// ```
319 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
320  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
321  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
322 
323  // Requires ranked MemRefType.
324  if (!sourceType || !resultType)
325  return false;
326 
327  // Requires same elemental type.
328  if (sourceType.getElementType() != resultType.getElementType())
329  return false;
330 
331  // Requires same rank.
332  if (sourceType.getRank() != resultType.getRank())
333  return false;
334 
335  // Only fold casts between strided memref forms.
336  int64_t sourceOffset, resultOffset;
337  SmallVector<int64_t, 4> sourceStrides, resultStrides;
338  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
339  failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
340  return false;
341 
342  // If cast is towards more static sizes along any dimension, don't fold.
343  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
344  auto ss = std::get<0>(it), st = std::get<1>(it);
345  if (ss != st)
346  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
347  return false;
348  }
349 
350  // If cast is towards more static offset along any dimension, don't fold.
351  if (sourceOffset != resultOffset)
352  if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
353  !ShapedType::isDynamicStrideOrOffset(resultOffset))
354  return false;
355 
356  // If cast is towards more static strides along any dimension, don't fold.
357  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
358  auto ss = std::get<0>(it), st = std::get<1>(it);
359  if (ss != st)
360  if (ShapedType::isDynamicStrideOrOffset(ss) &&
361  !ShapedType::isDynamicStrideOrOffset(st))
362  return false;
363  }
364 
365  return true;
366 }
367 
368 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
369  if (inputs.size() != 1 || outputs.size() != 1)
370  return false;
371  Type a = inputs.front(), b = outputs.front();
372  auto aT = a.dyn_cast<MemRefType>();
373  auto bT = b.dyn_cast<MemRefType>();
374 
375  auto uaT = a.dyn_cast<UnrankedMemRefType>();
376  auto ubT = b.dyn_cast<UnrankedMemRefType>();
377 
378  if (aT && bT) {
379  if (aT.getElementType() != bT.getElementType())
380  return false;
381  if (aT.getLayout() != bT.getLayout()) {
382  int64_t aOffset, bOffset;
383  SmallVector<int64_t, 4> aStrides, bStrides;
384  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
385  failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
386  aStrides.size() != bStrides.size())
387  return false;
388 
389  // Strides along a dimension/offset are compatible if the value in the
390  // source memref is static and the value in the target memref is the
391  // same. They are also compatible if either one is dynamic (see
392  // description of MemRefCastOp for details).
393  auto checkCompatible = [](int64_t a, int64_t b) {
394  return (a == MemRefType::getDynamicStrideOrOffset() ||
395  b == MemRefType::getDynamicStrideOrOffset() || a == b);
396  };
397  if (!checkCompatible(aOffset, bOffset))
398  return false;
399  for (const auto &aStride : enumerate(aStrides))
400  if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
401  return false;
402  }
403  if (aT.getMemorySpace() != bT.getMemorySpace())
404  return false;
405 
406  // They must have the same rank, and any specified dimensions must match.
407  if (aT.getRank() != bT.getRank())
408  return false;
409 
410  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
411  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
412  if (aDim != -1 && bDim != -1 && aDim != bDim)
413  return false;
414  }
415  return true;
416  } else {
417  if (!aT && !uaT)
418  return false;
419  if (!bT && !ubT)
420  return false;
421  // Unranked to unranked casting is unsupported
422  if (uaT && ubT)
423  return false;
424 
425  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
426  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
427  if (aEltType != bEltType)
428  return false;
429 
430  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
431  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
432  return aMemSpace == bMemSpace;
433  }
434 
435  return false;
436 }
437 
438 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
439  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // CopyOp
444 //===----------------------------------------------------------------------===//
445 
446 namespace {
447 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
448 /// and element type, the cast can be skipped. Such CastOps only cast the layout
449 /// of the type.
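/// A sketch of the intended rewrite (invented IR):
///
/// ```mlir
/// %0 = memref.cast %src
///        : memref<16xf32> to memref<16xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
/// memref.copy %0, %dst
///        : memref<16xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<16xf32>
/// ```
///
/// can copy directly from %src, since only the layout differs.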
450 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
451   using OpRewritePattern<CopyOp>::OpRewritePattern;
452 
453  LogicalResult matchAndRewrite(CopyOp copyOp,
454  PatternRewriter &rewriter) const override {
455  bool modified = false;
456 
457  // Check source.
458  if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
459  auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
460       auto toType = castOp.getType().dyn_cast<MemRefType>();
461 
462  if (fromType && toType) {
463  if (fromType.getShape() == toType.getShape() &&
464  fromType.getElementType() == toType.getElementType()) {
465  rewriter.updateRootInPlace(
466  copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
467  modified = true;
468  }
469  }
470  }
471 
472  // Check target.
473  if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
474  auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
475       auto toType = castOp.getType().dyn_cast<MemRefType>();
476 
477  if (fromType && toType) {
478  if (fromType.getShape() == toType.getShape() &&
479  fromType.getElementType() == toType.getElementType()) {
480  rewriter.updateRootInPlace(
481  copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
482  modified = true;
483  }
484  }
485  }
486 
487  return success(modified);
488  }
489 };
490 
491 /// Fold memref.copy(%x, %x).
492 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
493   using OpRewritePattern<CopyOp>::OpRewritePattern;
494 
495  LogicalResult matchAndRewrite(CopyOp copyOp,
496  PatternRewriter &rewriter) const override {
497  if (copyOp.source() != copyOp.target())
498  return failure();
499 
500  rewriter.eraseOp(copyOp);
501  return success();
502  }
503 };
504 } // namespace
505 
506 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
507  MLIRContext *context) {
508  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // DeallocOp
513 //===----------------------------------------------------------------------===//
514 
515 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
516                               SmallVectorImpl<OpFoldResult> &results) {
517   /// dealloc(memrefcast) -> dealloc
518  return foldMemRefCast(*this);
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // DimOp
523 //===----------------------------------------------------------------------===//
524 
525 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
526  int64_t index) {
527  auto loc = result.location;
528  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
529  build(builder, result, source, indexValue);
530 }
531 
532 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
533  Value index) {
534  auto indexTy = builder.getIndexType();
535  build(builder, result, indexTy, source, index);
536 }
537 
538 Optional<int64_t> DimOp::getConstantIndex() {
539  if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
540  return constantOp.getValue().cast<IntegerAttr>().getInt();
541  return {};
542 }
543 
544 static LogicalResult verify(DimOp op) {
545  // Assume unknown index to be in range.
546  Optional<int64_t> index = op.getConstantIndex();
547  if (!index.hasValue())
548  return success();
549 
550  // Check that constant index is not knowingly out of range.
551  auto type = op.source().getType();
552  if (auto memrefType = type.dyn_cast<MemRefType>()) {
553  if (index.getValue() >= memrefType.getRank())
554  return op.emitOpError("index is out of range");
555  } else if (type.isa<UnrankedMemRefType>()) {
556  // Assume index to be in range.
557  } else {
558  llvm_unreachable("expected operand with memref type");
559  }
560  return success();
561 }
562 
563 /// Return a map whose keys are the elements of `vals` and whose values are the
564 /// number of occurrences of each element. Use std::map, since the `vals` here
565 /// are strides and the dynamic stride value is the same as the tombstone value
566 /// for `DenseMap<int64_t>`.
567 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
568  std::map<int64_t, unsigned> numOccurences;
569  for (auto val : vals)
570  numOccurences[val]++;
571  return numOccurences;
572 }
573 
574 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
575 /// to be a subset of `originalType` with some `1` entries erased, return the
576 /// set of indices that specifies which of the entries of `originalShape` are
577 /// dropped to obtain `reducedShape`.
578 /// This accounts for cases where there are multiple unit-dims, but only a
579 /// subset of those are dropped. For MemRefTypes these can be disambiguated
580 /// using the strides. If a dimension is dropped the stride must be dropped too.
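/// As a hypothetical example: rank-reducing memref<1x4x1x8xf32> to
/// memref<4x1x8xf32> with mixed sizes [1, 4, 1, 8] makes both unit dimensions
/// (0 and 2) candidates, but only one of them was actually dropped; comparing
/// the multiset of strides of the two types determines which one.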
581 static llvm::Optional<llvm::SmallDenseSet<unsigned>>
582 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
583  ArrayRef<OpFoldResult> sizes) {
584  llvm::SmallDenseSet<unsigned> unusedDims;
585  if (originalType.getRank() == reducedType.getRank())
586  return unusedDims;
587 
588  for (const auto &dim : llvm::enumerate(sizes))
589  if (auto attr = dim.value().dyn_cast<Attribute>())
590  if (attr.cast<IntegerAttr>().getInt() == 1)
591  unusedDims.insert(dim.index());
592 
593  SmallVector<int64_t> originalStrides, candidateStrides;
594  int64_t originalOffset, candidateOffset;
595  if (failed(
596  getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
597  failed(
598  getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
599  return llvm::None;
600 
601  // For memrefs, a dimension is truly dropped if its corresponding stride is
602  // also dropped. This is particularly important when more than one of the dims
603   // is 1. Track the number of occurrences of the strides in the original type
604   // and the candidate type. For each unused dim, that stride should not be
605   // present in the candidate type. Note that there could be multiple dimensions
606   // that have the same size. We don't need to figure out exactly which dim
607   // corresponds to which stride; we just need to verify that the number of
608   // repetitions of a stride in the original + the number of unused dims with
609   // that stride == the number of repetitions of that stride in the candidate.
610  std::map<int64_t, unsigned> currUnaccountedStrides =
611  getNumOccurences(originalStrides);
612  std::map<int64_t, unsigned> candidateStridesNumOccurences =
613  getNumOccurences(candidateStrides);
614  llvm::SmallDenseSet<unsigned> prunedUnusedDims;
615  for (unsigned dim : unusedDims) {
616  int64_t originalStride = originalStrides[dim];
617  if (currUnaccountedStrides[originalStride] >
618  candidateStridesNumOccurences[originalStride]) {
619  // This dim can be treated as dropped.
620  currUnaccountedStrides[originalStride]--;
621  continue;
622  }
623  if (currUnaccountedStrides[originalStride] ==
624  candidateStridesNumOccurences[originalStride]) {
625  // The stride for this is not dropped. Keep as is.
626  prunedUnusedDims.insert(dim);
627  continue;
628  }
629  if (currUnaccountedStrides[originalStride] <
630  candidateStridesNumOccurences[originalStride]) {
631       // This should never happen. Can't have a stride in the reduced-rank type
632       // that wasn't in the original one.
633  return llvm::None;
634  }
635  }
636 
637  for (auto prunedDim : prunedUnusedDims)
638  unusedDims.erase(prunedDim);
639  if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
640  return llvm::None;
641  return unusedDims;
642 }
643 
644 llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
645  MemRefType sourceType = getSourceType();
646  MemRefType resultType = getType();
647   llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
648       computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
649  assert(unusedDims && "unable to find unused dims of subview");
650  return *unusedDims;
651 }
652 
653 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
654  // All forms of folding require a known index.
655  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
656  if (!index)
657  return {};
658 
659  // Folding for unranked types (UnrankedMemRefType) is not supported.
660  auto memrefType = source().getType().dyn_cast<MemRefType>();
661  if (!memrefType)
662  return {};
663 
664  // Fold if the shape extent along the given index is known.
665  if (!memrefType.isDynamicDim(index.getInt())) {
666  Builder builder(getContext());
667  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
668  }
669 
670  // The size at the given index is now known to be a dynamic size.
671  unsigned unsignedIndex = index.getValue().getZExtValue();
672 
673  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
674  Operation *definingOp = source().getDefiningOp();
675 
676  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
677  return *(alloc.getDynamicSizes().begin() +
678  memrefType.getDynamicDimIndex(unsignedIndex));
679 
680  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
681  return *(alloca.getDynamicSizes().begin() +
682  memrefType.getDynamicDimIndex(unsignedIndex));
683 
684  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
685  return *(view.getDynamicSizes().begin() +
686  memrefType.getDynamicDimIndex(unsignedIndex));
687 
688  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
689  llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
690  unsigned resultIndex = 0;
691  unsigned sourceRank = subview.getSourceType().getRank();
692  unsigned sourceIndex = 0;
693  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
694  if (unusedDims.count(i))
695  continue;
696  if (resultIndex == unsignedIndex) {
697  sourceIndex = i;
698  break;
699  }
700  resultIndex++;
701  }
702  assert(subview.isDynamicSize(sourceIndex) &&
703  "expected dynamic subview size");
704  return subview.getDynamicSize(sourceIndex);
705  }
706 
707  if (auto sizeInterface =
708  dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
709  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
710  "Expected dynamic subview size");
711  return sizeInterface.getDynamicSize(unsignedIndex);
712  }
713 
714  // dim(memrefcast) -> dim
715  if (succeeded(foldMemRefCast(*this)))
716  return getResult();
717 
718  return {};
719 }
720 
721 namespace {
722 /// Fold dim of a memref reshape operation to a load into the reshape's shape
723 /// operand.
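/// A sketch of the rewrite (invented IR):
///
/// ```mlir
/// %r = memref.reshape %src(%shape)
///        : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
/// %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// becomes a `memref.load %shape[%c1]` inserted right after the reshape.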
724 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
725   using OpRewritePattern<DimOp>::OpRewritePattern;
726 
727  LogicalResult matchAndRewrite(DimOp dim,
728  PatternRewriter &rewriter) const override {
729  auto reshape = dim.source().getDefiningOp<ReshapeOp>();
730 
731  if (!reshape)
732  return failure();
733 
734  // Place the load directly after the reshape to ensure that the shape memref
735  // was not mutated.
736  rewriter.setInsertionPointAfter(reshape);
737  Location loc = dim.getLoc();
738  Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
739  if (load.getType() != dim.getType())
740  load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
741  rewriter.replaceOp(dim, load);
742  return success();
743  }
744 };
745 
746 } // namespace
747 
748 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
749  MLIRContext *context) {
750  results.add<DimOfMemRefReshape>(context);
751 }
752 
753 // ---------------------------------------------------------------------------
754 // DmaStartOp
755 // ---------------------------------------------------------------------------
756 
757 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
758  Value srcMemRef, ValueRange srcIndices, Value destMemRef,
759  ValueRange destIndices, Value numElements,
760  Value tagMemRef, ValueRange tagIndices, Value stride,
761  Value elementsPerStride) {
762  result.addOperands(srcMemRef);
763  result.addOperands(srcIndices);
764  result.addOperands(destMemRef);
765  result.addOperands(destIndices);
766  result.addOperands({numElements, tagMemRef});
767  result.addOperands(tagIndices);
768  if (stride)
769  result.addOperands({stride, elementsPerStride});
770 }
771 
772 static void print(OpAsmPrinter &p, DmaStartOp op) {
773  p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
774  << op.getDstMemRef() << '[' << op.getDstIndices() << "], "
775  << op.getNumElements() << ", " << op.getTagMemRef() << '['
776  << op.getTagIndices() << ']';
777  if (op.isStrided())
778  p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();
779 
780  p.printOptionalAttrDict(op->getAttrs());
781  p << " : " << op.getSrcMemRef().getType() << ", "
782  << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
783 }
784 
785 // Parse DmaStartOp.
786 // Ex:
787 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
788 // %tag[%index], %stride, %num_elt_per_stride :
789 // : memref<3076 x f32, 0>,
790 // memref<1024 x f32, 2>,
791 // memref<1 x i32>
792 //
793 static ParseResult parseDmaStartOp(OpAsmParser &parser,
794                                    OperationState &result) {
795  OpAsmParser::OperandType srcMemRefInfo;
796  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
797  OpAsmParser::OperandType dstMemRefInfo;
798  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
799  OpAsmParser::OperandType numElementsInfo;
800  OpAsmParser::OperandType tagMemrefInfo;
801  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
802  SmallVector<OpAsmParser::OperandType, 2> strideInfo;
803 
804  SmallVector<Type, 3> types;
805  auto indexType = parser.getBuilder().getIndexType();
806 
807  // Parse and resolve the following list of operands:
808  // *) source memref followed by its indices (in square brackets).
809  // *) destination memref followed by its indices (in square brackets).
810  // *) dma size in KiB.
811  if (parser.parseOperand(srcMemRefInfo) ||
812  parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
813  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
814  parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
815  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
816  parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
817  parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
818  return failure();
819 
820  // Parse optional stride and elements per stride.
821  if (parser.parseTrailingOperandList(strideInfo))
822  return failure();
823 
824  bool isStrided = strideInfo.size() == 2;
825  if (!strideInfo.empty() && !isStrided) {
826  return parser.emitError(parser.getNameLoc(),
827  "expected two stride related operands");
828  }
829 
830  if (parser.parseColonTypeList(types))
831  return failure();
832  if (types.size() != 3)
833  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
834 
835  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
836  parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
837  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
838  parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
839  // size should be an index.
840  parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
841  parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
842  // tag indices should be index.
843  parser.resolveOperands(tagIndexInfos, indexType, result.operands))
844  return failure();
845 
846  if (isStrided) {
847  if (parser.resolveOperands(strideInfo, indexType, result.operands))
848  return failure();
849  }
850 
851  return success();
852 }
853 
854 static LogicalResult verify(DmaStartOp op) {
855  unsigned numOperands = op.getNumOperands();
856 
857  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
858  // the number of elements.
859  if (numOperands < 4)
860  return op.emitOpError("expected at least 4 operands");
861 
862  // Check types of operands. The order of these calls is important: the later
863  // calls rely on some type properties to compute the operand position.
864  // 1. Source memref.
865  if (!op.getSrcMemRef().getType().isa<MemRefType>())
866  return op.emitOpError("expected source to be of memref type");
867  if (numOperands < op.getSrcMemRefRank() + 4)
868  return op.emitOpError()
869  << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
870  if (!op.getSrcIndices().empty() &&
871  !llvm::all_of(op.getSrcIndices().getTypes(),
872  [](Type t) { return t.isIndex(); }))
873  return op.emitOpError("expected source indices to be of index type");
874 
875  // 2. Destination memref.
876  if (!op.getDstMemRef().getType().isa<MemRefType>())
877  return op.emitOpError("expected destination to be of memref type");
878  unsigned numExpectedOperands =
879  op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
880  if (numOperands < numExpectedOperands)
881  return op.emitOpError()
882  << "expected at least " << numExpectedOperands << " operands";
883  if (!op.getDstIndices().empty() &&
884  !llvm::all_of(op.getDstIndices().getTypes(),
885  [](Type t) { return t.isIndex(); }))
886  return op.emitOpError("expected destination indices to be of index type");
887 
888  // 3. Number of elements.
889  if (!op.getNumElements().getType().isIndex())
890  return op.emitOpError("expected num elements to be of index type");
891 
892  // 4. Tag memref.
893  if (!op.getTagMemRef().getType().isa<MemRefType>())
894  return op.emitOpError("expected tag to be of memref type");
895  numExpectedOperands += op.getTagMemRefRank();
896  if (numOperands < numExpectedOperands)
897  return op.emitOpError()
898  << "expected at least " << numExpectedOperands << " operands";
899  if (!op.getTagIndices().empty() &&
900  !llvm::all_of(op.getTagIndices().getTypes(),
901  [](Type t) { return t.isIndex(); }))
902  return op.emitOpError("expected tag indices to be of index type");
903 
904  // Optional stride-related operands must be either both present or both
905  // absent.
906  if (numOperands != numExpectedOperands &&
907  numOperands != numExpectedOperands + 2)
908  return op.emitOpError("incorrect number of operands");
909 
910  // 5. Strides.
911  if (op.isStrided()) {
912  if (!op.getStride().getType().isIndex() ||
913  !op.getNumElementsPerStride().getType().isIndex())
914  return op.emitOpError(
915  "expected stride and num elements per stride to be of type index");
916  }
917 
918  return success();
919 }
920 
921 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
922                                SmallVectorImpl<OpFoldResult> &results) {
923   /// dma_start(memrefcast) -> dma_start
924  return foldMemRefCast(*this);
925 }
926 
927 // ---------------------------------------------------------------------------
928 // DmaWaitOp
929 // ---------------------------------------------------------------------------
930 
931 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
932                               SmallVectorImpl<OpFoldResult> &results) {
933   /// dma_wait(memrefcast) -> dma_wait
934  return foldMemRefCast(*this);
935 }
936 
937 static LogicalResult verify(DmaWaitOp op) {
938  // Check that the number of tag indices matches the tagMemRef rank.
939  unsigned numTagIndices = op.tagIndices().size();
940  unsigned tagMemRefRank = op.getTagMemRefRank();
941  if (numTagIndices != tagMemRefRank)
942  return op.emitOpError() << "expected tagIndices to have the same number of "
943  "elements as the tagMemRef rank, expected "
944  << tagMemRefRank << ", but got " << numTagIndices;
945  return success();
946 }
947 
948 //===----------------------------------------------------------------------===//
949 // GenericAtomicRMWOp
950 //===----------------------------------------------------------------------===//
951 
952 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
953  Value memref, ValueRange ivs) {
954  result.addOperands(memref);
955  result.addOperands(ivs);
956 
957  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
958  Type elementType = memrefType.getElementType();
959  result.addTypes(elementType);
960 
961  Region *bodyRegion = result.addRegion();
962  bodyRegion->push_back(new Block());
963  bodyRegion->addArgument(elementType, memref.getLoc());
964  }
965 }
966 
967 static LogicalResult verify(GenericAtomicRMWOp op) {
968  auto &body = op.getRegion();
969  if (body.getNumArguments() != 1)
970  return op.emitOpError("expected single number of entry block arguments");
971 
972  if (op.getResult().getType() != body.getArgument(0).getType())
973  return op.emitOpError(
974  "expected block argument of the same type result type");
975 
976  bool hasSideEffects =
977  body.walk([&](Operation *nestedOp) {
978  if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
979  return WalkResult::advance();
980  nestedOp->emitError(
981  "body of 'memref.generic_atomic_rmw' should contain "
982  "only operations with no side effects");
983  return WalkResult::interrupt();
984  })
985  .wasInterrupted();
986  return hasSideEffects ? failure() : success();
987 }
988 
989 static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
990                                            OperationState &result) {
991   OpAsmParser::OperandType memref;
992   Type memrefType;
993  SmallVector<OpAsmParser::OperandType, 4> ivs;
994 
995  Type indexType = parser.getBuilder().getIndexType();
996  if (parser.parseOperand(memref) ||
997       parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
998       parser.parseColonType(memrefType) ||
999  parser.resolveOperand(memref, memrefType, result.operands) ||
1000  parser.resolveOperands(ivs, indexType, result.operands))
1001  return failure();
1002 
1003  Region *body = result.addRegion();
1004  if (parser.parseRegion(*body, llvm::None, llvm::None) ||
1005  parser.parseOptionalAttrDict(result.attributes))
1006  return failure();
1007  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
1008  return success();
1009 }
1010 
1011 static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
1012  p << ' ' << op.memref() << "[" << op.indices()
1013  << "] : " << op.memref().getType() << ' ';
1014  p.printRegion(op.getRegion());
1015  p.printOptionalAttrDict(op->getAttrs());
1016 }
1017 
1018 //===----------------------------------------------------------------------===//
1019 // AtomicYieldOp
1020 //===----------------------------------------------------------------------===//
1021 
1022 static LogicalResult verify(AtomicYieldOp op) {
1023  Type parentType = op->getParentOp()->getResultTypes().front();
1024  Type resultType = op.result().getType();
1025  if (parentType != resultType)
1026  return op.emitOpError() << "types mismatch between yield op: " << resultType
1027  << " and its parent: " << parentType;
1028  return success();
1029 }
1030 
1031 //===----------------------------------------------------------------------===//
1032 // GlobalOp
1033 //===----------------------------------------------------------------------===//
1034 
1035 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1036                                                    TypeAttr type,
1037  Attribute initialValue) {
1038  p << type;
1039  if (!op.isExternal()) {
1040  p << " = ";
1041  if (op.isUninitialized())
1042  p << "uninitialized";
1043  else
1044  p.printAttributeWithoutType(initialValue);
1045  }
1046 }
1047 
1048 static ParseResult
1049 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1050                                        Attribute &initialValue) {
1051  Type type;
1052  if (parser.parseType(type))
1053  return failure();
1054 
1055  auto memrefType = type.dyn_cast<MemRefType>();
1056  if (!memrefType || !memrefType.hasStaticShape())
1057  return parser.emitError(parser.getNameLoc())
1058  << "type should be static shaped memref, but got " << type;
1059  typeAttr = TypeAttr::get(type);
1060 
1061  if (parser.parseOptionalEqual())
1062  return success();
1063 
1064  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1065  initialValue = UnitAttr::get(parser.getContext());
1066  return success();
1067  }
1068 
1069  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1070  if (parser.parseAttribute(initialValue, tensorType))
1071  return failure();
1072  if (!initialValue.isa<ElementsAttr>())
1073  return parser.emitError(parser.getNameLoc())
1074  << "initial value should be a unit or elements attribute";
1075  return success();
1076 }
1077 
1078 static LogicalResult verify(GlobalOp op) {
1079  auto memrefType = op.type().dyn_cast<MemRefType>();
1080  if (!memrefType || !memrefType.hasStaticShape())
1081  return op.emitOpError("type should be static shaped memref, but got ")
1082  << op.type();
1083 
1084  // Verify that the initial value, if present, is either a unit attribute or
1085  // an elements attribute.
1086  if (op.initial_value().hasValue()) {
1087  Attribute initValue = op.initial_value().getValue();
1088  if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1089  return op.emitOpError("initial value should be a unit or elements "
1090  "attribute, but got ")
1091  << initValue;
1092 
1093  // Check that the type of the initial value is compatible with the type of
1094  // the global variable.
1095  if (initValue.isa<ElementsAttr>()) {
1096  Type initType = initValue.getType();
1097  Type tensorType = getTensorTypeFromMemRefType(memrefType);
1098  if (initType != tensorType)
1099  return op.emitOpError("initial value expected to be of type ")
1100  << tensorType << ", but was of type " << initType;
1101  }
1102  }
1103 
1104  if (Optional<uint64_t> alignAttr = op.alignment()) {
1105  uint64_t alignment = alignAttr.getValue();
1106 
1107  if (!llvm::isPowerOf2_64(alignment))
1108  return op->emitError() << "alignment attribute value " << alignment
1109  << " is not a power of 2";
1110  }
1111 
1112  // TODO: verify visibility for declarations.
1113  return success();
1114 }
1115 
1116 //===----------------------------------------------------------------------===//
1117 // GetGlobalOp
1118 //===----------------------------------------------------------------------===//
1119 
1120 LogicalResult
1121 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1122  // Verify that the result type is same as the type of the referenced
1123  // memref.global op.
1124  auto global =
1125  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1126  if (!global)
1127  return emitOpError("'")
1128  << name() << "' does not reference a valid global memref";
1129 
1130  Type resultType = result().getType();
1131  if (global.type() != resultType)
1132  return emitOpError("result type ")
1133  << resultType << " does not match type " << global.type()
1134  << " of the global memref @" << name();
1135  return success();
1136 }
1137 
1138 //===----------------------------------------------------------------------===//
1139 // LoadOp
1140 //===----------------------------------------------------------------------===//
1141 
1142 static LogicalResult verify(LoadOp op) {
1143  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1144  return op.emitOpError("incorrect number of indices for load");
1145  return success();
1146 }
1147 
1148 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1149  /// load(memrefcast) -> load
1150  if (succeeded(foldMemRefCast(*this)))
1151  return getResult();
1152  return OpFoldResult();
1153 }
1154 
1155 //===----------------------------------------------------------------------===//
1156 // PrefetchOp
1157 //===----------------------------------------------------------------------===//
1158 
1159 static void print(OpAsmPrinter &p, PrefetchOp op) {
1160  p << " " << op.memref() << '[';
1161  p.printOperands(op.indices());
1162  p << ']' << ", " << (op.isWrite() ? "write" : "read");
1163  p << ", locality<" << op.localityHint();
1164  p << ">, " << (op.isDataCache() ? "data" : "instr");
1165   p.printOptionalAttrDict(
1166       op->getAttrs(),
1167  /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1168  p << " : " << op.getMemRefType();
1169 }
1170 
1171 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1172                                    OperationState &result) {
1173  OpAsmParser::OperandType memrefInfo;
1174  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1175  IntegerAttr localityHint;
1176  MemRefType type;
1177  StringRef readOrWrite, cacheType;
1178 
1179  auto indexTy = parser.getBuilder().getIndexType();
1180  auto i32Type = parser.getBuilder().getIntegerType(32);
1181  if (parser.parseOperand(memrefInfo) ||
1182  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1183  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1184  parser.parseComma() || parser.parseKeyword("locality") ||
1185  parser.parseLess() ||
1186  parser.parseAttribute(localityHint, i32Type, "localityHint",
1187  result.attributes) ||
1188  parser.parseGreater() || parser.parseComma() ||
1189  parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1190  parser.resolveOperand(memrefInfo, type, result.operands) ||
1191  parser.resolveOperands(indexInfo, indexTy, result.operands))
1192  return failure();
1193 
1194  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1195  return parser.emitError(parser.getNameLoc(),
1196  "rw specifier has to be 'read' or 'write'");
1197  result.addAttribute(
1198  PrefetchOp::getIsWriteAttrName(),
1199  parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1200 
1201  if (!cacheType.equals("data") && !cacheType.equals("instr"))
1202  return parser.emitError(parser.getNameLoc(),
1203  "cache type has to be 'data' or 'instr'");
1204 
1205  result.addAttribute(
1206  PrefetchOp::getIsDataCacheAttrName(),
1207  parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1208 
1209  return success();
1210 }
1211 
1212 static LogicalResult verify(PrefetchOp op) {
1213  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1214  return op.emitOpError("too few indices");
1215 
1216  return success();
1217 }
1218 
1219 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1220  SmallVectorImpl<OpFoldResult> &results) {
1221  // prefetch(memrefcast) -> prefetch
1222  return foldMemRefCast(*this);
1223 }
1224 
1225 //===----------------------------------------------------------------------===//
1226 // RankOp
1227 //===----------------------------------------------------------------------===//
1228 
1229 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1230  // Constant fold rank when the rank of the operand is known.
1231  auto type = getOperand().getType();
1232  auto shapedType = type.dyn_cast<ShapedType>();
1233  if (shapedType && shapedType.hasRank())
1234  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1235  return IntegerAttr();
1236 }
1237 
1238 //===----------------------------------------------------------------------===//
1239 // ReinterpretCastOp
1240 //===----------------------------------------------------------------------===//
1241 
1242 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1243 /// `staticSizes` and `staticStrides` are automatically filled with
1244 /// source-memref-rank sentinel values that encode dynamic entries.
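/// For reference, a hand-written (illustrative) example of the op being built,
/// here with fully static entries:
///
/// ```mlir
/// %dst = memref.reinterpret_cast %src to
///          offset: [0], sizes: [10, 10], strides: [10, 1]
///          : memref<?x?xf32> to memref<10x10xf32>
/// ```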
1245 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1246  MemRefType resultType, Value source,
1247  OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1248  ArrayRef<OpFoldResult> strides,
1249  ArrayRef<NamedAttribute> attrs) {
1250  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1251  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1252  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1253  ShapedType::kDynamicStrideOrOffset);
1254  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1255  ShapedType::kDynamicSize);
1256  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1257  ShapedType::kDynamicStrideOrOffset);
1258  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1259  dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1260  b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1261  result.addAttributes(attrs);
1262 }
1263 
1264 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1265  MemRefType resultType, Value source,
1266  int64_t offset, ArrayRef<int64_t> sizes,
1267  ArrayRef<int64_t> strides,
1268  ArrayRef<NamedAttribute> attrs) {
1269  SmallVector<OpFoldResult> sizeValues =
1270  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1271  return b.getI64IntegerAttr(v);
1272  }));
1273  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1274  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1275  return b.getI64IntegerAttr(v);
1276  }));
1277  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1278  strideValues, attrs);
1279 }
1280 
1281 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1282  MemRefType resultType, Value source, Value offset,
1283  ValueRange sizes, ValueRange strides,
1284  ArrayRef<NamedAttribute> attrs) {
1285  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1286  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1287  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1288  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1289  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1290 }
1291 
1292 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1293 // completed automatically, like we have for subview and extract_slice.
1294 static LogicalResult verify(ReinterpretCastOp op) {
1295  // The source and result memrefs should be in the same memory space.
1296  auto srcType = op.source().getType().cast<BaseMemRefType>();
1297  auto resultType = op.getType().cast<MemRefType>();
1298  if (srcType.getMemorySpace() != resultType.getMemorySpace())
1299  return op.emitError("different memory spaces specified for source type ")
1300  << srcType << " and result memref type " << resultType;
1301  if (srcType.getElementType() != resultType.getElementType())
1302  return op.emitError("different element types specified for source type ")
1303  << srcType << " and result memref type " << resultType;
1304 
1305  // Match sizes in result memref type and in static_sizes attribute.
1306  for (auto &en :
1307  llvm::enumerate(llvm::zip(resultType.getShape(),
1308  extractFromI64ArrayAttr(op.static_sizes())))) {
1309  int64_t resultSize = std::get<0>(en.value());
1310  int64_t expectedSize = std::get<1>(en.value());
1311  if (!ShapedType::isDynamic(resultSize) &&
1312  !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1313  return op.emitError("expected result type with size = ")
1314  << expectedSize << " instead of " << resultSize
1315  << " in dim = " << en.index();
1316  }
1317 
1318  // Match offset and strides in static_offset and static_strides attributes. If
1319  // result memref type has no affine map specified, this will assume an
1320  // identity layout.
1321  int64_t resultOffset;
1322  SmallVector<int64_t, 4> resultStrides;
1323  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1324  return op.emitError(
1325  "expected result type to have strided layout but found ")
1326  << resultType;
1327 
1328  // Match offset in result memref type and in static_offsets attribute.
1329  int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
1330  if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
1331  !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
1332  resultOffset != expectedOffset)
1333  return op.emitError("expected result type with offset = ")
1334  << resultOffset << " instead of " << expectedOffset;
1335 
1336  // Match strides in result memref type and in static_strides attribute.
1337  for (auto &en : llvm::enumerate(llvm::zip(
1338  resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1339  int64_t resultStride = std::get<0>(en.value());
1340  int64_t expectedStride = std::get<1>(en.value());
1341  if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
1342  !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
1343  resultStride != expectedStride)
1344  return op.emitError("expected result type with stride = ")
1345  << expectedStride << " instead of " << resultStride
1346  << " in dim = " << en.index();
1347  }
1348 
1349  return success();
1350 }
1351 
1352 //===----------------------------------------------------------------------===//
1353 // Reassociative reshape ops
1354 //===----------------------------------------------------------------------===//
1355 
1356 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1357  return getSymbolLessAffineMaps(getReassociationExprs());
1358 }
1359 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1360  return convertReassociationIndicesToExprs(getContext(),
1361  getReassociationIndices());
1362 }
1363 
1364 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1365  return getSymbolLessAffineMaps(getReassociationExprs());
1366 }
1367 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1368  return convertReassociationIndicesToExprs(getContext(),
1369  getReassociationIndices());
1370 }
1371 
1372 static void print(OpAsmPrinter &p, ExpandShapeOp op) {
1373  ::mlir::printReshapeOp<ExpandShapeOp>(p, op);
1374 }
1375 
1376 static void print(OpAsmPrinter &p, CollapseShapeOp op) {
1377  ::mlir::printReshapeOp<CollapseShapeOp>(p, op);
1378 }
1379 
1380 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
1381 /// copies.
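/// For instance (illustrative): in a memref<4x8xf32> with strides [8, 1], the
/// band covering both dimensions is reshapable because 8 == 1 * 8, so it can
/// be collapsed to memref<32xf32> without a copy.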
1382 static bool isReshapableDimBand(unsigned dim, unsigned extent,
1383  ArrayRef<int64_t> sizes,
1384  ArrayRef<AffineExpr> strides) {
1385  // Bands of extent one can be reshaped, as they are not reshaped at all.
1386  if (extent == 1)
1387  return true;
1388  // Otherwise, the size of the first dimension needs to be known.
1389  if (ShapedType::isDynamic(sizes[dim]))
1390  return false;
1391  assert(sizes.size() == strides.size() && "mismatched ranks");
1392  // off by 1 indexing to avoid out of bounds
1393  // V
1394  for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
1395  // Only bands of static shapes are reshapable. This is due to the fact that
1396  // there is no relation between dynamic sizes and dynamic strides: we do not
1397  // have enough information to know whether a "-1" size corresponds to the
1398  // proper symbol in the AffineExpr of a stride.
1399  if (ShapedType::isDynamic(sizes[idx + 1]))
1400  return false;
1401  // TODO: Refine this by passing the proper nDims and nSymbols so we can
1402  // simplify on the fly and catch more reshapable cases.
1403  if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
1404  return false;
1405  }
1406  return true;
1407 }
1408 
1409 /// Compute the MemRefType obtained by applying the `reassociation` (which is
1410 /// expected to be valid) to `type`.
1411 /// If `type` is a contiguous MemRefType, this always produces a contiguous
1412 /// MemRefType.
1413 static MemRefType
1414 computeReshapeCollapsedType(MemRefType type,
1415                             ArrayRef<AffineMap> reassociation) {
1416  auto sizes = type.getShape();
1417  AffineExpr offset;
1418  SmallVector<AffineExpr, 4> strides;
1419  auto status = getStridesAndOffset(type, strides, offset);
1420  auto isIdentityLayout = type.getLayout().isIdentity();
1421  (void)status;
1422  assert(succeeded(status) && "expected strided memref");
1423 
1424  SmallVector<int64_t, 4> newSizes;
1425  newSizes.reserve(reassociation.size());
1426  SmallVector<AffineExpr, 4> newStrides;
1427  newStrides.reserve(reassociation.size());
1428 
1429  // Use the fact that reassociation is valid to simplify the logic: only use
1430  // each map's rank.
1431  assert(isReassociationValid(reassociation) && "invalid reassociation");
1432  unsigned currentDim = 0;
1433  for (AffineMap m : reassociation) {
1434  unsigned dim = m.getNumResults();
1435  int64_t size = 1;
1436  AffineExpr stride = strides[currentDim + dim - 1];
1437  if (isIdentityLayout ||
1438  isReshapableDimBand(currentDim, dim, sizes, strides)) {
1439  for (unsigned d = 0; d < dim; ++d) {
1440  int64_t currentSize = sizes[currentDim + d];
1441  if (ShapedType::isDynamic(currentSize)) {
1442  size = ShapedType::kDynamicSize;
1443  break;
1444  }
1445  size *= currentSize;
1446  }
1447  } else {
1448  size = ShapedType::kDynamicSize;
1449  stride = AffineExpr();
1450  }
1451  newSizes.push_back(size);
1452  newStrides.push_back(stride);
1453  currentDim += dim;
1454  }
1455 
1456  // Early-exit: if `type` is contiguous, the result must be contiguous.
1457  if (canonicalizeStridedLayout(type).getLayout().isIdentity())
1458  return MemRefType::Builder(type).setShape(newSizes).setLayout({});
1459 
1460  // Convert back to int64_t because we don't have enough information to create
1461  // new strided layouts from AffineExpr only. This corresponds to a case where
1462  // copies may be necessary.
1463  int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
1464  if (auto o = offset.dyn_cast<AffineConstantExpr>())
1465  intOffset = o.getValue();
1466  SmallVector<int64_t, 4> intStrides;
1467  intStrides.reserve(strides.size());
1468  for (auto stride : newStrides) {
1469  if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
1470  intStrides.push_back(cst.getValue());
1471  else
1472  intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
1473  }
1474  auto layout =
1475  makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
1476   return canonicalizeStridedLayout(
1477       MemRefType::Builder(type).setShape(newSizes).setLayout(
1478           AffineMapAttr::get(layout)));
1479 }
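// Illustrative example (hypothetical types, added for exposition): applying
// the reassociation {[0, 1], [2]} to a contiguous memref<4x8x16xf32> takes
// the contiguity early exit above and yields memref<32x16xf32>. If the
// collapsed band is not contiguous (e.g. strides [256, 16, 1]), the collapsed
// dimension instead gets a dynamic size and stride in the resulting strided
// layout.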
1480 
1481 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1482  ArrayRef<ReassociationIndices> reassociation,
1483  ArrayRef<NamedAttribute> attrs) {
1484  auto memRefType = src.getType().cast<MemRefType>();
1485   auto resultType = computeReshapeCollapsedType(
1486       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1487                       b.getContext(), reassociation)));
1488   build(b, result, resultType, src, attrs);
1489   result.addAttribute(getReassociationAttrName(),
1490                       getReassociationIndicesAttribute(b, reassociation));
1491 }
1492 
1493 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1494  ArrayRef<ReassociationIndices> reassociation,
1495  ArrayRef<NamedAttribute> attrs) {
1496  auto memRefType = src.getType().cast<MemRefType>();
1497   auto resultType = computeReshapeCollapsedType(
1498       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1499                       b.getContext(), reassociation)));
1500   build(b, result, resultType, src, attrs);
1501   result.addAttribute(getReassociationAttrName(),
1502                       getReassociationIndicesAttribute(b, reassociation));
1503 }
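// Illustrative IR for the two builders above (hypothetical values, added for
// exposition), assuming a statically shaped, contiguous source:
//
//   %e = memref.expand_shape %m [[0, 1], [2]]
//       : memref<32x16xf32> into memref<4x8x16xf32>
//   %c = memref.collapse_shape %e [[0, 1], [2]]
//       : memref<4x8x16xf32> into memref<32x16xf32>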
1504 
1505 template <typename ReshapeOp,
1506           bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
1507 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
1508  MemRefType collapsedType) {
1509  if (failed(
1510  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1511  return failure();
1512  auto maps = op.getReassociationMaps();
1513  MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
1514  if (collapsedType != expectedType)
1515  return op.emitOpError("expected collapsed type to be ")
1516  << expectedType << ", but got " << collapsedType;
1517  return success();
1518 }
1519 
1520 static LogicalResult verify(ExpandShapeOp op) {
1521  return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
1522 }
1523 
1524 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1525                                                 MLIRContext *context) {
1526   results.add<CollapseReshapeOps<ExpandShapeOp>,
1527               CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
1528 }
1529 
1530 static LogicalResult verify(CollapseShapeOp op) {
1531  return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
1532 }
1533 
1534 class CollapseShapeOpMemRefCastFolder final
1535     : public OpRewritePattern<CollapseShapeOp> {
1536 public:
1537   using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1538 
1539  LogicalResult matchAndRewrite(CollapseShapeOp op,
1540  PatternRewriter &rewriter) const override {
1541  auto cast = op.getOperand().getDefiningOp<CastOp>();
1542  if (!cast)
1543  return failure();
1544 
1545  if (!CastOp::canFoldIntoConsumerOp(cast))
1546  return failure();
1547 
1548  Type newResultType = computeReshapeCollapsedType(
1549  cast.getOperand().getType().cast<MemRefType>(),
1550  op.getReassociationMaps());
1551 
1552  if (newResultType == op.getResultType()) {
1553  rewriter.updateRootInPlace(
1554  op, [&]() { op.srcMutable().assign(cast.source()); });
1555  } else {
1556  Value newOp = rewriter.create<CollapseShapeOp>(
1557  op->getLoc(), cast.source(), op.getReassociationIndices());
1558  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
1559  }
1560  return success();
1561  }
1562 };
1563 
1564 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1565                                                   MLIRContext *context) {
1566   results.add<CollapseReshapeOps<CollapseShapeOp>,
1567               CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
1568               CollapseShapeOpMemRefCastFolder>(context);
1569 }
1570 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
1571  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
1572 }
1573 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
1574  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
1575 }
1576 
1577 //===----------------------------------------------------------------------===//
1578 // ReshapeOp
1579 //===----------------------------------------------------------------------===//
1580 
1581 static LogicalResult verify(ReshapeOp op) {
1582  Type operandType = op.source().getType();
1583  Type resultType = op.result().getType();
1584 
1585  Type operandElementType = operandType.cast<ShapedType>().getElementType();
1586  Type resultElementType = resultType.cast<ShapedType>().getElementType();
1587  if (operandElementType != resultElementType)
1588  return op.emitOpError("element types of source and destination memref "
1589  "types should be the same");
1590 
1591  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1592  if (!operandMemRefType.getLayout().isIdentity())
1593  return op.emitOpError(
1594  "source memref type should have identity affine map");
1595 
1596  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1597  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1598  if (resultMemRefType) {
1599  if (!resultMemRefType.getLayout().isIdentity())
1600  return op.emitOpError(
1601  "result memref type should have identity affine map");
1602  if (shapeSize == ShapedType::kDynamicSize)
1603  return op.emitOpError("cannot use shape operand with dynamic length to "
1604  "reshape to statically-ranked memref type");
1605  if (shapeSize != resultMemRefType.getRank())
1606  return op.emitOpError(
1607  "length of shape operand differs from the result's memref rank");
1608  }
1609  return success();
1610 }
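// Illustrative IR that satisfies the checks above (hypothetical values, added
// for exposition): matching f32 element types, identity layouts, and a static
// shape operand of length 2 for a rank-2 result:
//
//   %dst = memref.reshape %src(%shape)
//       : (memref<4x16xf32>, memref<2xindex>) -> memref<8x8xf32>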
1611 
1612 //===----------------------------------------------------------------------===//
1613 // StoreOp
1614 //===----------------------------------------------------------------------===//
1615 
1616 static LogicalResult verify(StoreOp op) {
1617  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1618  return op.emitOpError("store index operand count not equal to memref rank");
1619 
1620  return success();
1621 }
1622 
1623 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1624  SmallVectorImpl<OpFoldResult> &results) {
1625  /// store(memrefcast) -> store
1626  return foldMemRefCast(*this, getValueToStore());
1627 }
1628 
1629 //===----------------------------------------------------------------------===//
1630 // SubViewOp
1631 //===----------------------------------------------------------------------===//
1632 
1633 namespace {
1634 /// Helpers to write more idiomatic operations.
1635 namespace saturated_arith {
1636 struct Wrapper {
1637  explicit Wrapper(int64_t v) : v(v) {}
1638  operator int64_t() { return v; }
1639  int64_t v;
1640 };
1641 Wrapper operator+(Wrapper a, int64_t b) {
1642  if (ShapedType::isDynamicStrideOrOffset(a) ||
1643  ShapedType::isDynamicStrideOrOffset(b))
1644  return Wrapper(ShapedType::kDynamicStrideOrOffset);
1645  return Wrapper(a.v + b);
1646 }
1647 Wrapper operator*(Wrapper a, int64_t b) {
1648  if (ShapedType::isDynamicStrideOrOffset(a) ||
1649  ShapedType::isDynamicStrideOrOffset(b))
1650  return Wrapper(ShapedType::kDynamicStrideOrOffset);
1651  return Wrapper(a.v * b);
1652 }
1653 } // namespace saturated_arith
1654 } // namespace
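// Illustrative use of the saturating helpers above (added for exposition):
// Wrapper(4) * 8 + 16 evaluates to 48, while any operand equal to
// ShapedType::kDynamicStrideOrOffset "poisons" the result, e.g.
// Wrapper(ShapedType::kDynamicStrideOrOffset) * 8 + 16 stays dynamic.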
1655 
1656 /// A subview result type can be fully inferred from the source type and the
1657 /// static representation of offsets, sizes and strides. Special sentinels
1658 /// encode the dynamic case.
1659 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1660  ArrayRef<int64_t> staticOffsets,
1661  ArrayRef<int64_t> staticSizes,
1662  ArrayRef<int64_t> staticStrides) {
1663  unsigned rank = sourceMemRefType.getRank();
1664  (void)rank;
1665  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
1666  assert(staticSizes.size() == rank && "staticSizes length mismatch");
1667  assert(staticStrides.size() == rank && "staticStrides length mismatch");
1668 
1669  // Extract source offset and strides.
1670  int64_t sourceOffset;
1671  SmallVector<int64_t, 4> sourceStrides;
1672  auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1673  assert(succeeded(res) && "SubViewOp expected strided memref type");
1674  (void)res;
1675 
1676  // Compute target offset whose value is:
1677  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1678  int64_t targetOffset = sourceOffset;
1679  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
1680  auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
1681  using namespace saturated_arith;
1682  targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
1683  }
1684 
1685  // Compute target stride whose value is:
1686  // `sourceStrides_i * staticStrides_i`.
1687  SmallVector<int64_t, 4> targetStrides;
1688  targetStrides.reserve(staticOffsets.size());
1689  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1690  auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1691  using namespace saturated_arith;
1692  targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1693  }
1694 
1695  // The type is now known.
1696  return MemRefType::get(
1697  staticSizes, sourceMemRefType.getElementType(),
1698  makeStridedLinearLayoutMap(targetStrides, targetOffset,
1699  sourceMemRefType.getContext()),
1700  sourceMemRefType.getMemorySpace());
1701 }
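// Illustrative example (hypothetical values, added for exposition): for a
// source memref<8x16xf32> (strides [16, 1], offset 0) with static offsets
// [2, 4], sizes [4, 8] and strides [1, 2], the inferred type is
//   memref<4x8xf32, affine_map<(d0, d1) -> (d0 * 16 + d1 * 2 + 36)>>
// because the new offset is 0 + 2 * 16 + 4 * 1 = 36 and the new strides are
// [16 * 1, 1 * 2] = [16, 2].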
1702 
1703 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1704  ArrayRef<OpFoldResult> offsets,
1705  ArrayRef<OpFoldResult> sizes,
1706  ArrayRef<OpFoldResult> strides) {
1707  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1708  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1709  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1710  ShapedType::kDynamicStrideOrOffset);
1711  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1712  ShapedType::kDynamicSize);
1713  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1714  ShapedType::kDynamicStrideOrOffset);
1715  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1716  staticSizes, staticStrides);
1717 }
1718 
1719 Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
1720  MemRefType sourceRankedTensorType,
1721  ArrayRef<int64_t> offsets,
1722  ArrayRef<int64_t> sizes,
1723  ArrayRef<int64_t> strides) {
1724  auto inferredType =
1725  inferResultType(sourceRankedTensorType, offsets, sizes, strides)
1726  .cast<MemRefType>();
1727   assert(inferredType.getRank() >= resultRank && "expected inferred rank >= result rank");
1728  int rankDiff = inferredType.getRank() - resultRank;
1729  if (rankDiff > 0) {
1730  auto shape = inferredType.getShape();
1731  llvm::SmallDenseSet<unsigned> dimsToProject;
1732  mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1733  SmallVector<int64_t> projectedShape;
1734  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1735  if (!dimsToProject.contains(pos))
1736  projectedShape.push_back(shape[pos]);
1737 
1738  AffineMap map = inferredType.getLayout().getAffineMap();
1739  if (!map.isIdentity())
1740  map = getProjectedMap(map, dimsToProject);
1741  inferredType =
1742  MemRefType::get(projectedShape, inferredType.getElementType(), map,
1743  inferredType.getMemorySpace());
1744  }
1745  return inferredType;
1746 }
1747 
1748 Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
1749  MemRefType sourceRankedTensorType,
1750  ArrayRef<OpFoldResult> offsets,
1751  ArrayRef<OpFoldResult> sizes,
1752  ArrayRef<OpFoldResult> strides) {
1753  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1754  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1755  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1756  ShapedType::kDynamicStrideOrOffset);
1757  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1758  ShapedType::kDynamicSize);
1759  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1760  ShapedType::kDynamicStrideOrOffset);
1761  return SubViewOp::inferRankReducedResultType(
1762  resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1763  staticStrides);
1764 }
1765 // Build a SubViewOp with mixed static and dynamic entries and custom result
1766 // type. If the type passed is nullptr, it is inferred.
1767 void SubViewOp::build(OpBuilder &b, OperationState &result,
1768  MemRefType resultType, Value source,
1769  ArrayRef<OpFoldResult> offsets,
1770  ArrayRef<OpFoldResult> sizes,
1771  ArrayRef<OpFoldResult> strides,
1772  ArrayRef<NamedAttribute> attrs) {
1773  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1774  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1775  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1776  ShapedType::kDynamicStrideOrOffset);
1777  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1778  ShapedType::kDynamicSize);
1779  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1780  ShapedType::kDynamicStrideOrOffset);
1781  auto sourceMemRefType = source.getType().cast<MemRefType>();
1782  // Structuring implementation this way avoids duplication between builders.
1783  if (!resultType) {
1784  resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1785  staticSizes, staticStrides)
1786  .cast<MemRefType>();
1787  }
1788  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1789  dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1790  b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1791  result.addAttributes(attrs);
1792 }
1793 
1794 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1795 // type.
1796 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1797  ArrayRef<OpFoldResult> offsets,
1798  ArrayRef<OpFoldResult> sizes,
1799  ArrayRef<OpFoldResult> strides,
1800  ArrayRef<NamedAttribute> attrs) {
1801  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1802 }
1803 
1804 // Build a SubViewOp with static entries and inferred result type.
1805 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1806  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1807  ArrayRef<int64_t> strides,
1808  ArrayRef<NamedAttribute> attrs) {
1809  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1810  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1811  return b.getI64IntegerAttr(v);
1812  }));
1813  SmallVector<OpFoldResult> sizeValues =
1814  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1815  return b.getI64IntegerAttr(v);
1816  }));
1817  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1818  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1819  return b.getI64IntegerAttr(v);
1820  }));
1821  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1822 }
1823 
1824 // Build a SubViewOp with dynamic entries and custom result type. If the
1825 // type passed is nullptr, it is inferred.
1826 void SubViewOp::build(OpBuilder &b, OperationState &result,
1827  MemRefType resultType, Value source,
1828  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1829  ArrayRef<int64_t> strides,
1830  ArrayRef<NamedAttribute> attrs) {
1831  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1832  llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1833  return b.getI64IntegerAttr(v);
1834  }));
1835  SmallVector<OpFoldResult> sizeValues =
1836  llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1837  return b.getI64IntegerAttr(v);
1838  }));
1839  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1840  llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1841  return b.getI64IntegerAttr(v);
1842  }));
1843  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1844  attrs);
1845 }
1846 
1847 // Build a SubViewOp with dynamic entries and custom result type. If the type
1848 // passed is nullptr, it is inferred.
1849 void SubViewOp::build(OpBuilder &b, OperationState &result,
1850  MemRefType resultType, Value source, ValueRange offsets,
1851  ValueRange sizes, ValueRange strides,
1852  ArrayRef<NamedAttribute> attrs) {
1853  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1854  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1855  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1856  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1857  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1858  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1859  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1860 }
1861 
1862 // Build a SubViewOp with dynamic entries and inferred result type.
1863 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1864  ValueRange offsets, ValueRange sizes, ValueRange strides,
1865  ArrayRef<NamedAttribute> attrs) {
1866  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1867 }
1868 
1869 /// For ViewLikeOpInterface.
1870 Value SubViewOp::getViewSource() { return source(); }
1871 
1872 /// Return true if t1 and t2 have equal offsets (both dynamic or of same static
1873 /// value).
1874 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
1875  AffineExpr t1Offset, t2Offset;
1876  SmallVector<AffineExpr> t1Strides, t2Strides;
1877  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
1878  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
1879  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
1880 }
1881 
1882 /// Checks if the `original` type can be rank-reduced to the `reduced` type.
1883 /// This function is a slight variant of the `is subsequence` algorithm,
1884 /// where every non-matching dimension must have static size 1.
1885 static SliceVerificationResult
1886 isRankReducedMemRefType(MemRefType originalType,
1887  MemRefType candidateRankReducedType,
1888  ArrayRef<OpFoldResult> sizes) {
1889  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
1890  if (partialRes != SliceVerificationResult::Success)
1891  return partialRes;
1892 
1893  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
1894  originalType, candidateRankReducedType, sizes);
1895 
1896   // If no mask is returned, the sizes could not be matched.
1897   if (!optionalUnusedDimsMask.hasValue())
1898     return SliceVerificationResult::LayoutMismatch;
1899 
1900  if (originalType.getMemorySpace() !=
1901  candidateRankReducedType.getMemorySpace())
1902     return SliceVerificationResult::MemSpaceMismatch;
1903 
1904  // No amount of stride dropping can reconcile incompatible offsets.
1905  if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
1906     return SliceVerificationResult::LayoutMismatch;
1907 
1908   return SliceVerificationResult::Success;
1909 }
1910 
1911 template <typename OpTy>
1912 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
1913                                             OpTy op, Type expectedType) {
1914   auto memrefType = expectedType.cast<ShapedType>();
1915   switch (result) {
1916   case SliceVerificationResult::Success:
1917     return success();
1918   case SliceVerificationResult::RankTooLarge:
1919     return op.emitError("expected result rank to be smaller or equal to ")
1920            << "the source rank. ";
1921   case SliceVerificationResult::SizeMismatch:
1922     return op.emitError("expected result type to be ")
1923            << expectedType
1924            << " or a rank-reduced version. (mismatch of result sizes) ";
1925   case SliceVerificationResult::ElemTypeMismatch:
1926     return op.emitError("expected result element type to be ")
1927            << memrefType.getElementType();
1928   case SliceVerificationResult::MemSpaceMismatch:
1929     return op.emitError("expected result and source memory spaces to match.");
1930   case SliceVerificationResult::LayoutMismatch:
1931     return op.emitError("expected result type to be ")
1932            << expectedType
1933            << " or a rank-reduced version. (mismatch of result layout) ";
1934   }
1935  llvm_unreachable("unexpected subview verification result");
1936 }
1937 
1938 /// Verifier for SubViewOp.
1939 static LogicalResult verify(SubViewOp op) {
1940  MemRefType baseType = op.getSourceType();
1941  MemRefType subViewType = op.getType();
1942 
1943  // The base memref and the view memref should be in the same memory space.
1944  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
1945  return op.emitError("different memory spaces specified for base memref "
1946  "type ")
1947  << baseType << " and subview memref type " << subViewType;
1948 
1949  // Verify that the base memref type has a strided layout map.
1950  if (!isStrided(baseType))
1951  return op.emitError("base type ") << baseType << " is not strided";
1952 
1953  // Verify result type against inferred type.
1954  auto expectedType = SubViewOp::inferResultType(
1955  baseType, extractFromI64ArrayAttr(op.static_offsets()),
1956  extractFromI64ArrayAttr(op.static_sizes()),
1957  extractFromI64ArrayAttr(op.static_strides()));
1958 
1959  auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
1960  subViewType, op.getMixedSizes());
1961  return produceSubViewErrorMsg(result, op, expectedType);
1962 }
1963 
1964 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
1965  return os << "range " << range.offset << ":" << range.size << ":"
1966  << range.stride;
1967 }
1968 
1969 /// Return the list of Range (i.e. offset, size, stride). Each Range
1970 /// entry contains either the dynamic value or a ConstantIndexOp constructed
1971 /// with `b` at location `loc`.
1972 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
1973  OpBuilder &b, Location loc) {
1974  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
1975  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
1976  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
1977   SmallVector<Range, 8> res;
1978   unsigned rank = ranks[0];
1979  res.reserve(rank);
1980  for (unsigned idx = 0; idx < rank; ++idx) {
1981  Value offset =
1982  op.isDynamicOffset(idx)
1983  ? op.getDynamicOffset(idx)
1984  : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
1985  Value size =
1986  op.isDynamicSize(idx)
1987  ? op.getDynamicSize(idx)
1988  : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
1989  Value stride =
1990  op.isDynamicStride(idx)
1991  ? op.getDynamicStride(idx)
1992  : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
1993  res.emplace_back(Range{offset, size, stride});
1994  }
1995  return res;
1996 }
1997 
1998 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
1999 /// deduce the result type for the given `sourceType`. Additionally, reduce the
2000 /// rank of the inferred result type if `currentResultType` is lower rank than
2001 /// `currentSourceType`. Use this signature if `sourceType` is updated together
2002 /// with the result type. In this case, it is important to compute the dropped
2003 /// dimensions using `currentSourceType` whose strides align with
2004 /// `currentResultType`.
2005 static MemRefType getCanonicalSubViewResultType(
2006     MemRefType currentResultType, MemRefType currentSourceType,
2007  MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
2008  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
2009  auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
2010  mixedSizes, mixedStrides)
2011  .cast<MemRefType>();
2012   llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
2013       computeMemRefRankReductionMask(currentSourceType, currentResultType,
2014  mixedSizes);
2015  // Return nullptr as failure mode.
2016  if (!unusedDims)
2017  return nullptr;
2018  SmallVector<int64_t> shape;
2019  for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
2020  if (unusedDims->count(sizes.index()))
2021  continue;
2022  shape.push_back(sizes.value());
2023  }
2024  AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
2025  if (!layoutMap.isIdentity())
2026  layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
2027  return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
2028  nonRankReducedType.getMemorySpace());
2029 }
2030 
2031 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
2032 /// deduce the result type. Additionally, reduce the rank of the inferred result
2033 /// type if `currentResultType` is lower rank than `sourceType`.
2034 static MemRefType getCanonicalSubViewResultType(
2035     MemRefType currentResultType, MemRefType sourceType,
2036  ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
2037  ArrayRef<OpFoldResult> mixedStrides) {
2038  return getCanonicalSubViewResultType(currentResultType, sourceType,
2039  sourceType, mixedOffsets, mixedSizes,
2040  mixedStrides);
2041 }
2042 
2043 /// Helper method to check if a `subview` operation is trivially a no-op. This
2044 /// is the case if all offsets are zero, all strides are one, and the source
2045 /// shape is the same as the size of the subview. In such cases, the subview can be
2046 /// folded into its source.
2047 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
2048  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
2049  return false;
2050 
2051  auto mixedOffsets = subViewOp.getMixedOffsets();
2052  auto mixedSizes = subViewOp.getMixedSizes();
2053  auto mixedStrides = subViewOp.getMixedStrides();
2054 
2055  // Check offsets are zero.
2056  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
2057  Optional<int64_t> intValue = getConstantIntValue(ofr);
2058  return !intValue || intValue.getValue() != 0;
2059  }))
2060  return false;
2061 
2062  // Check strides are one.
2063  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
2064  Optional<int64_t> intValue = getConstantIntValue(ofr);
2065  return !intValue || intValue.getValue() != 1;
2066  }))
2067  return false;
2068 
2069   // Check that all size values are static and match the (static) source shape.
2070  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
2071  for (const auto &size : llvm::enumerate(mixedSizes)) {
2072  Optional<int64_t> intValue = getConstantIntValue(size.value());
2073  if (!intValue || intValue.getValue() != sourceShape[size.index()])
2074  return false;
2075  }
2076  // All conditions met. The `SubViewOp` is foldable as a no-op.
2077  return true;
2078 }
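// Illustrative example (hypothetical IR, added for exposition):
//   %0 = memref.subview %src[0, 0] [8, 16] [1, 1]
//       : memref<8x16xf32> to memref<8x16xf32>
// has zero offsets, unit strides, and sizes equal to the static source shape,
// so it is a no-op and the TrivialSubViewOpFolder pattern below replaces it
// with %src.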
2079 
2080 namespace {
2081 /// Pattern to rewrite a subview op with MemRefCast arguments.
2082 /// This essentially pushes memref.cast past its consuming subview when
2083 /// `canFoldIntoConsumerOp` is true.
2084 ///
2085 /// Example:
2086 /// ```
2087 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
2088 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
2089 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
2090 /// ```
2091 /// is rewritten into:
2092 /// ```
2093 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
2094 /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
2095 /// memref<3x4xf32, offset:?, strides:[?, 1]>
2096 /// ```
2097 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2098 public:
2099   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2100 
2101  LogicalResult matchAndRewrite(SubViewOp subViewOp,
2102  PatternRewriter &rewriter) const override {
2103     // If any operand is constant, return and let SubViewOpConstantFolder kick in.
2104  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2105  return matchPattern(operand, matchConstantIndex());
2106  }))
2107  return failure();
2108 
2109  auto castOp = subViewOp.source().getDefiningOp<CastOp>();
2110  if (!castOp)
2111  return failure();
2112 
2113  if (!CastOp::canFoldIntoConsumerOp(castOp))
2114  return failure();
2115 
2116  // Compute the SubViewOp result type after folding the MemRefCastOp. Use the
2117  // MemRefCastOp source operand type to infer the result type and the current
2118  // SubViewOp source operand type to compute the dropped dimensions if the
2119  // operation is rank-reducing.
2120  auto resultType = getCanonicalSubViewResultType(
2121  subViewOp.getType(), subViewOp.getSourceType(),
2122  castOp.source().getType().cast<MemRefType>(),
2123  subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
2124  subViewOp.getMixedStrides());
2125  if (!resultType)
2126  return failure();
2127 
2128  Value newSubView = rewriter.create<SubViewOp>(
2129  subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
2130  subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
2131  subViewOp.static_sizes(), subViewOp.static_strides());
2132  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2133  newSubView);
2134  return success();
2135  }
2136 };
2137 
2138 /// Canonicalize subview ops that are no-ops. If the source and result types
2139 /// differ only in their layout (e.g. due to an `affine_map`), a cast is inserted.
2140 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
2141 public:
2142   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2143 
2144  LogicalResult matchAndRewrite(SubViewOp subViewOp,
2145  PatternRewriter &rewriter) const override {
2146  if (!isTrivialSubViewOp(subViewOp))
2147  return failure();
2148  if (subViewOp.getSourceType() == subViewOp.getType()) {
2149  rewriter.replaceOp(subViewOp, subViewOp.source());
2150  return success();
2151  }
2152  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.source(),
2153  subViewOp.getType());
2154  return success();
2155  }
2156 };
2157 } // namespace
2158 
2159 /// Return the canonical type of the result of a subview.
2160 struct SubViewReturnTypeCanonicalizer {
2161   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
2162  ArrayRef<OpFoldResult> mixedSizes,
2163  ArrayRef<OpFoldResult> mixedStrides) {
2164  return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
2165  mixedOffsets, mixedSizes,
2166  mixedStrides);
2167  }
2168 };
2169 
2170 /// A canonicalizer wrapper to replace SubViewOps.
2171 struct SubViewCanonicalizer {
2172   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
2173  rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
2174  }
2175 };
2176 
2177 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2178  MLIRContext *context) {
2179   results
2180       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2181                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
2182            SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
2183 }
2184 
2185 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
2186  auto resultShapedType = getResult().getType().cast<ShapedType>();
2187  auto sourceShapedType = source().getType().cast<ShapedType>();
2188 
2189  if (resultShapedType.hasStaticShape() &&
2190  resultShapedType == sourceShapedType) {
2191  return getViewSource();
2192  }
2193 
2194  return {};
2195 }
2196 
2197 //===----------------------------------------------------------------------===//
2198 // TransposeOp
2199 //===----------------------------------------------------------------------===//
2200 
2201 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
2202 static MemRefType inferTransposeResultType(MemRefType memRefType,
2203  AffineMap permutationMap) {
2204  auto rank = memRefType.getRank();
2205  auto originalSizes = memRefType.getShape();
2206  // Compute permuted sizes.
2207  SmallVector<int64_t, 4> sizes(rank, 0);
2208  for (const auto &en : llvm::enumerate(permutationMap.getResults()))
2209  sizes[en.index()] =
2210  originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2211 
2212  // Compute permuted strides.
2213  int64_t offset;
2214  SmallVector<int64_t, 4> strides;
2215  auto res = getStridesAndOffset(memRefType, strides, offset);
2216  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2217  (void)res;
2218  auto map =
2219  makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2220  map = permutationMap ? map.compose(permutationMap) : map;
2221  return MemRefType::Builder(memRefType)
2222  .setShape(sizes)
2223  .setLayout(AffineMapAttr::get(map));
2224 }
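// Illustrative example (hypothetical values, added for exposition): with the
// permutation (d0, d1) -> (d1, d0), a source memref<3x4xf32> (strides [4, 1])
// maps to
//   memref<4x3xf32, affine_map<(d0, d1) -> (d0 + d1 * 4)>>
// i.e. the shape is permuted and the strided layout is composed with the
// permutation, so element (i, j) of the result addresses element (j, i) of
// the source buffer.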
2225 
2226 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2227  AffineMapAttr permutation,
2228  ArrayRef<NamedAttribute> attrs) {
2229  auto permutationMap = permutation.getValue();
2230  assert(permutationMap);
2231 
2232  auto memRefType = in.getType().cast<MemRefType>();
2233  // Compute result type.
2234  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2235 
2236  build(b, result, resultType, in, attrs);
2237  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2238 }
2239 
2240 // transpose $in $permutation attr-dict : type($in) `to` type(results)
2241 static void print(OpAsmPrinter &p, TransposeOp op) {
2242  p << " " << op.in() << " " << op.permutation();
2243  p.printOptionalAttrDict(op->getAttrs(),
2244  {TransposeOp::getPermutationAttrName()});
2245  p << " : " << op.in().getType() << " to " << op.getType();
2246 }
2247 
2248 static ParseResult parseTransposeOp(OpAsmParser &parser,
2249                                     OperationState &result) {
2250   OpAsmParser::OperandType in;
2251   AffineMap permutation;
2252  MemRefType srcType, dstType;
2253  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2254  parser.parseOptionalAttrDict(result.attributes) ||
2255  parser.parseColonType(srcType) ||
2256  parser.resolveOperand(in, srcType, result.operands) ||
2257  parser.parseKeywordType("to", dstType) ||
2258  parser.addTypeToList(dstType, result.types))
2259  return failure();
2260 
2261  result.addAttribute(TransposeOp::getPermutationAttrName(),
2262  AffineMapAttr::get(permutation));
2263  return success();
2264 }
2265 
2266 static LogicalResult verify(TransposeOp op) {
2267  if (!op.permutation().isPermutation())
2268  return op.emitOpError("expected a permutation map");
2269  if (op.permutation().getNumDims() != op.getShapedType().getRank())
2270  return op.emitOpError(
2271  "expected a permutation map of same rank as the input");
2272 
2273  auto srcType = op.in().getType().cast<MemRefType>();
2274  auto dstType = op.getType().cast<MemRefType>();
2275  auto transposedType = inferTransposeResultType(srcType, op.permutation());
2276  if (dstType != transposedType)
2277  return op.emitOpError("output type ")
2278  << dstType << " does not match transposed input type " << srcType
2279  << ", " << transposedType;
2280  return success();
2281 }
2282 
2283 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2284  if (succeeded(foldMemRefCast(*this)))
2285  return getResult();
2286  return {};
2287 }
2288 
2289 //===----------------------------------------------------------------------===//
2290 // ViewOp
2291 //===----------------------------------------------------------------------===//
2292 
2293 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
2294   OpAsmParser::OperandType srcInfo;
2295  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
2296  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
2297  auto indexType = parser.getBuilder().getIndexType();
2298  Type srcType, dstType;
2299  llvm::SMLoc offsetLoc;
2300   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
2301       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
2302     return failure();
2303 
2304  if (offsetInfo.size() != 1)
2305  return parser.emitError(offsetLoc) << "expects 1 offset operand";
2306 
2307  return failure(
2308  parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
2309  parser.parseOptionalAttrDict(result.attributes) ||
2310  parser.parseColonType(srcType) ||
2311  parser.resolveOperand(srcInfo, srcType, result.operands) ||
2312  parser.resolveOperands(offsetInfo, indexType, result.operands) ||
2313  parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2314  parser.parseKeywordType("to", dstType) ||
2315  parser.addTypeToList(dstType, result.types));
2316 }
2317 
2318 static void print(OpAsmPrinter &p, ViewOp op) {
2319  p << ' ' << op.getOperand(0) << '[';
2320  p.printOperand(op.byte_shift());
2321  p << "][" << op.sizes() << ']';
2322  p.printOptionalAttrDict(op->getAttrs());
2323  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2324 }
2325 
2326 static LogicalResult verify(ViewOp op) {
2327  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
2328  auto viewType = op.getType();
2329 
2330  // The base memref should have identity layout map (or none).
2331  if (!baseType.getLayout().isIdentity())
2332  return op.emitError("unsupported map for base memref type ") << baseType;
2333 
2334  // The result memref should have identity layout map (or none).
2335  if (!viewType.getLayout().isIdentity())
2336  return op.emitError("unsupported map for result memref type ") << viewType;
2337 
2338  // The base memref and the view memref should be in the same memory space.
2339  if (baseType.getMemorySpace() != viewType.getMemorySpace())
2340  return op.emitError("different memory spaces specified for base memref "
2341  "type ")
2342  << baseType << " and view memref type " << viewType;
2343 
2344  // Verify that we have the correct number of sizes for the result type.
2345  unsigned numDynamicDims = viewType.getNumDynamicDims();
2346  if (op.sizes().size() != numDynamicDims)
2347  return op.emitError("incorrect number of size operands for type ")
2348  << viewType;
2349 
2350  return success();
2351 }
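// Illustrative IR that satisfies the checks above (hypothetical values, added
// for exposition): identity layouts on both types, the same memory space, and
// one size operand for the single dynamic result dimension:
//
//   %v = memref.view %buf[%byte_shift][%n]
//       : memref<2048xi8> to memref<?x4xf32>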
2352 
2353 Value ViewOp::getViewSource() { return source(); }
2354 
2355 namespace {
2356 
2357 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2358   using OpRewritePattern<ViewOp>::OpRewritePattern;
2359 
2360  LogicalResult matchAndRewrite(ViewOp viewOp,
2361  PatternRewriter &rewriter) const override {
2362  // Return if none of the operands are constants.
2363  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2364  return matchPattern(operand, matchConstantIndex());
2365  }))
2366  return failure();
2367 
2368  // Get result memref type.
2369  auto memrefType = viewOp.getType();
2370 
2371  // Get offset from old memref view type 'memRefType'.
2372  int64_t oldOffset;
2373  SmallVector<int64_t, 4> oldStrides;
2374  if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2375  return failure();
2376  assert(oldOffset == 0 && "Expected 0 offset");
2377 
2378  SmallVector<Value, 4> newOperands;
2379 
2380  // Offset cannot be folded into result type.
2381 
2382  // Fold any dynamic dim operands which are produced by a constant.
2383  SmallVector<int64_t, 4> newShapeConstants;
2384  newShapeConstants.reserve(memrefType.getRank());
2385 
2386  unsigned dynamicDimPos = 0;
2387  unsigned rank = memrefType.getRank();
2388  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2389  int64_t dimSize = memrefType.getDimSize(dim);
2390       // If this is already a static dimension, keep it.
2391  if (!ShapedType::isDynamic(dimSize)) {
2392  newShapeConstants.push_back(dimSize);
2393  continue;
2394  }
2395  auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2396  if (auto constantIndexOp =
2397  dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
2398  // Dynamic shape dimension will be folded.
2399  newShapeConstants.push_back(constantIndexOp.value());
2400  } else {
2401  // Dynamic shape dimension not folded; copy operand from old memref.
2402  newShapeConstants.push_back(dimSize);
2403  newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2404  }
2405  dynamicDimPos++;
2406  }
2407 
2408  // Create new memref type with constant folded dims.
2409  MemRefType newMemRefType =
2410  MemRefType::Builder(memrefType).setShape(newShapeConstants);
2411  // Nothing new, don't fold.
2412  if (newMemRefType == memrefType)
2413  return failure();
2414 
2415  // Create new ViewOp.
2416  auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2417  viewOp.getOperand(0),
2418  viewOp.byte_shift(), newOperands);
2419  // Insert a cast so we have the same type as the old memref type.
2420  rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
2421  return success();
2422  }
2423 };
2424 
2425 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2426   using OpRewritePattern<ViewOp>::OpRewritePattern;
2427 
2428  LogicalResult matchAndRewrite(ViewOp viewOp,
2429  PatternRewriter &rewriter) const override {
2430  Value memrefOperand = viewOp.getOperand(0);
2431  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2432  if (!memrefCastOp)
2433  return failure();
2434  Value allocOperand = memrefCastOp.getOperand();
2435  AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2436  if (!allocOp)
2437  return failure();
2438  rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2439  viewOp.byte_shift(), viewOp.sizes());
2440  return success();
2441  }
2442 };
2443 
2444 } // namespace
2445 
2446 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2447  MLIRContext *context) {
2448  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2449 }
2450 
2451 //===----------------------------------------------------------------------===//
2452 // AtomicRMWOp
2453 //===----------------------------------------------------------------------===//
2454 
2455 static LogicalResult verify(AtomicRMWOp op) {
2456  if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
2457  return op.emitOpError(
2458  "expects the number of subscripts to be equal to memref rank");
2459  switch (op.kind()) {
2460  case arith::AtomicRMWKind::addf:
2461  case arith::AtomicRMWKind::maxf:
2462  case arith::AtomicRMWKind::minf:
2463  case arith::AtomicRMWKind::mulf:
2464  if (!op.value().getType().isa<FloatType>())
2465  return op.emitOpError()
2466  << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
2467  << "' expects a floating-point type";
2468  break;
2469  case arith::AtomicRMWKind::addi:
2470  case arith::AtomicRMWKind::maxs:
2471  case arith::AtomicRMWKind::maxu:
2472  case arith::AtomicRMWKind::mins:
2473  case arith::AtomicRMWKind::minu:
2474  case arith::AtomicRMWKind::muli:
2475  case arith::AtomicRMWKind::ori:
2476  case arith::AtomicRMWKind::andi:
2477  if (!op.value().getType().isa<IntegerType>())
2478  return op.emitOpError()
2479  << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
2480  << "' expects an integer type";
2481  break;
2482  default:
2483  break;
2484  }
2485  return success();
2486 }
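// Illustrative IR that passes the checks above (hypothetical values, added
// for exposition): a floating-point kind with an f32 value and one subscript
// for a rank-1 memref:
//
//   %old = memref.atomic_rmw "addf" %val, %buf[%i]
//       : (f32, memref<10xf32>) -> f32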
2487 
2488 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
2489  /// atomicrmw(memrefcast) -> atomicrmw
2490  if (succeeded(foldMemRefCast(*this, value())))
2491  return getResult();
2492  return OpFoldResult();
2493 }
2494 
2495 //===----------------------------------------------------------------------===//
2496 // TableGen'd op method definitions
2497 //===----------------------------------------------------------------------===//
2498 
2499 #define GET_OP_CLASSES
2500 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
virtual ParseResult parseOperand(OperandType &result)=0
Parse a single operand.
This is the representation of an operand reference.
Include the generated interface declarations.
Pattern to rewrite a subview op with constant arguments.
Definition: Utils.h:40
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, OpTy op, Type expectedType)
Definition: MemRefOps.cpp:1912
ParseResult resolveOperands(ArrayRef< OperandType > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
MLIRContext * getContext() const
Definition: Builders.h:54
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into &#39;map&#39;.
static ParseResult parsePrefetchOp(OpAsmParser &parser, OperationState &result)
Definition: MemRefOps.cpp:1171
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:444
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
Definition: MemRefOps.cpp:1874
Block represents an ordered list of Operations.
Definition: Block.h:29
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec, int64_t sentinel)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
A trait of region holding operations that define a new scope for automatic allocations, i.e., allocations that are freed when control is transferred back from the operation&#39;s region.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:244
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
Definition: MemRefOps.cpp:1539
bool isa() const
Definition: Attributes.h:107
void push_back(Block *block)
Definition: Region.h:61
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
void getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape, llvm::SmallDenseSet< unsigned > &dimsToProject)
Definition: Utils.cpp:42
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition: MemRefOps.cpp:62
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static bool isReshapableDimBand(unsigned dim, unsigned extent, ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > strides)
Detect whether memref dims [dim, dim + extent) can be reshaped without copies.
Definition: MemRefOps.cpp:1382
A canonicalizer wrapper to replace SubViewOps.
Definition: MemRefOps.cpp:2171
AffineMap getProjectedMap(AffineMap map, const llvm::SmallDenseSet< unsigned > &projectedDimensions)
Returns the map that results from projecting out the dimensions specified in projectedDimensions.
Definition: AffineMap.cpp:736
virtual ParseResult parseTrailingOperandList(SmallVectorImpl< OperandType > &result, int requiredOperandCount=-1, Delimiter delimiter=Delimiter::None)=0
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:220
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual ParseResult parseComma()=0
Parse a , token.
virtual llvm::SMLoc getNameLoc() const =0
Return the location of the original name token.
Helpers to write more idiomatic operations.
Definition: MemRefOps.cpp:1635
static constexpr const bool value
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Region.h:98
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of...
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
Auxiliary range data structure to unpack the offset, size and stride operands into a list of triples...
virtual ParseResult parseOperandList(SmallVectorImpl< OperandType > &result, int requiredOperandCount=-1, Delimiter delimiter=Delimiter::None)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:343
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
static ParseResult parseAllocaScopeOp(OpAsmParser &parser, OperationState &result)
Definition: MemRefOps.cpp:226
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:252
U dyn_cast() const
Definition: AffineExpr.h:281
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
virtual ParseResult resolveOperand(const OperandType &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SmallVector< OpFoldResult, 4 > getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes, ValueRange sizes)
Return a vector of all the static or dynamic sizes of the op from provided external static and dynami...
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:242
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
bool isStrided(MemRefType t)
Return true if the layout for t is compatible with strided semantics.
virtual ParseResult parseRegion(Region &region, ArrayRef< OperandType > arguments={}, ArrayRef< Type > argTypes={}, ArrayRef< Location > argLocations={}, bool enableNameShadowing=false)=0
Parses a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it...
Definition: MemRefOps.cpp:567
virtual ParseResult parseGreater()=0
Parse a &#39;>&#39; token.
void addOperands(ValueRange newOperands)
virtual llvm::SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
Definition: MemRefOps.cpp:75
U dyn_cast() const
Definition: Types.h:244
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:99
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:334
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
Base type for affine expression.
Definition: AffineExpr.h:68
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
void addTypes(ArrayRef< Type > newTypes)
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
virtual ParseResult parseLess()=0
Parse a &#39;<&#39; token.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
static WalkResult advance()
Definition: Visitors.h:51
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:789
This represents an operation in an abstracted form, suitable for use with the builder APIs...
A multi-dimensional affine map Affine map&#39;s are immutable like Type&#39;s, and they are uniqued...
Definition: AffineMap.h:38
static WalkResult interrupt()
Definition: Visitors.h:50
Fraction operator*(Fraction x, Fraction y)
Definition: Fraction.h:73
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:311
static void print(OpAsmPrinter &p, AllocaScopeOp &op)
Definition: MemRefOps.cpp:211
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:91
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, MemRefType collapsedType)
Definition: MemRefOps.cpp:1507
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
AffineExpr operator+(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:244
static ParseResult parseTransposeOp(OpAsmParser &parser, OperationState &result)
Definition: MemRefOps.cpp:2248
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
Definition: MemRefOps.cpp:1972
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Definition: MemRefOps.cpp:2172
static bool hasSideEffects(Operation *op)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
Definition: MemRefOps.cpp:2005
static int resultIndex(int i)
Definition: Operator.cpp:308
NamedAttrList attributes
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:64
virtual InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
This class represents a successor of a region.
Region * addRegion()
Create a region that should be attached to the operation.
static llvm::Optional< llvm::SmallDenseSet< unsigned > > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
Definition: MemRefOps.cpp:582
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types &#39;Ts&#39; to the pattern list with the given arguments...
Definition: PatternMatch.h:930
IndexType getIndexType()
Definition: Builders.cpp:48
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:109
static ParseResult parseDmaStartOp(OpAsmParser &parser, OperationState &result)
Definition: MemRefOps.cpp:793
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:266
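Two small sketches of the matcher entry point; sizeOperand is an assumed Value in scope:

  // Is the operand produced by an arith.constant of index type?
  if (matchPattern(sizeOperand, matchConstantIndex())) {
    // ... the size is a statically known index constant.
  }
  // Bind the constant attribute instead of just testing for it.
  IntegerAttr cst;
  if (matchPattern(sizeOperand, m_Constant(&cst))) {
    int64_t size = cst.getInt();
    (void)size;
  }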
static MemRefType computeReshapeCollapsedType(MemRefType type, ArrayRef< AffineMap > reassociation)
Compute the MemRefType obtained by applying the reassociation (which is expected to be valid) to type...
Definition: MemRefOps.cpp:1414
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:161
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
Definition: MemRefOps.cpp:1035
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
Definition: MemRefOps.cpp:2047
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:78
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:255
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
virtual ParseResult parseType(Type &result)=0
Parse a type.
U dyn_cast_or_null() const
Definition: Value.h:103
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
AffineMap makeStridedLinearLayoutMap(ArrayRef< int64_t > strides, int64_t offset, MLIRContext *context)
Given a list of strides (in which MemRefType::getDynamicStrideOrOffset() represents a dynamic value)...
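As a sketch, the layout of a row-major 4x16 memref (strides [16, 1], offset 0) could be built as follows, assuming ctx is an MLIRContext*:

  // Produces roughly (d0, d1) -> (d0 * 16 + d1) once the zero offset folds away.
  AffineMap layout = makeStridedLinearLayoutMap({16, 1}, /*offset=*/0, ctx);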
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
Definition: MemRefOps.cpp:2202
This class represents an operand of an operation.
Definition: Value.h:249
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: MemRefOps.cpp:2161
static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser, OperationState &result)
Definition: MemRefOps.cpp:989
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:182
static SliceVerificationResult isRankReducedMemRefType(MemRefType originalType, MemRefType candidateRankReducedType, ArrayRef< OpFoldResult > sizes)
Checks if original Type type can be rank reduced to reduced type.
Definition: MemRefOps.cpp:1886
SmallVector< int64_t, 4 > extractFromI64ArrayAttr(Attribute attr)
Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
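A minimal sketch; the attribute name "static_sizes" is illustrative, not an API guarantee:

  // Pull the raw int64_t values out of an ArrayAttr of IntegerAttr.
  SmallVector<int64_t, 4> staticSizes =
      extractFromI64ArrayAttr(op->getAttr("static_sizes"));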
bool isa() const
Definition: Types.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:61
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result)
Definition: MemRefOps.cpp:2293
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:47
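A sketch of the intended use inside an op's fold hook, modeled on the load/store folders in this file:

  OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
    /// load(memrefcast) -> load
    if (succeeded(foldMemRefCast(*this)))
      return getResult();
    return OpFoldResult();
  }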
Optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
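A small sketch, assuming mixedSize is one entry of a SubViewOp's getMixedSizes():

  if (Optional<int64_t> cst = getConstantIntValue(mixedSize)) {
    // *cst is the statically known size; a unit size often marks a rank-reduced dim.
  }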
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:95
virtual ParseResult parseOptionalEqual()=0
Parse a '=' token if present.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition: MemRefOps.cpp:1049
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:172
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
Return the canonical type of the result of a subview.
Definition: MemRefOps.cpp:2160
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type. ...
Definition: FoldUtils.cpp:50
Square brackets surrounding zero or more operands.
U cast() const
Definition: Types.h:250
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
SmallVector< Type, 4 > types
Types of the results of this operation.
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition: Utils.cpp:23