MLIR  20.0.0git
Shape.cpp
Go to the documentation of this file.
1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
12 
17 #include "mlir/Dialect/Traits.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
27 #include "llvm/ADT/SetOperations.h"
28 #include "llvm/ADT/SmallString.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/raw_ostream.h"
31 
32 using namespace mlir;
33 using namespace mlir::shape;
34 
35 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
36 
37 namespace {
38 #include "ShapeCanonicalization.inc"
39 } // namespace
40 
41 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
42  return RankedTensorType::get({rank}, IndexType::get(ctx));
43 }
44 
// Returns true iff `type` is a ranked tensor of rank 1 whose element type is
// `index`, i.e. an extent tensor.
// NOTE(review): the signature line (listing line 45) is missing from this
// rendering; the body reads a parameter named `type` — confirm against the
// original source.
46  auto ranked = llvm::dyn_cast<RankedTensorType>(type);
47  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
48 }
49 
// Appends the statically known extents of the shape value `input` to
// `shapeValues`; existing contents of `shapeValues` are kept.
// Succeeds when `input` is produced by a ShapeOfOp whose argument is ranked
// (the argument type's shape is appended), or when `input` matches a constant
// (its integer values are appended). Fails for an unranked ShapeOfOp argument
// and for any other producer.
50 LogicalResult shape::getShapeVec(Value input,
51  SmallVectorImpl<int64_t> &shapeValues) {
52  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
53  auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
54  if (!type.hasRank())
55  return failure();
56  llvm::append_range(shapeValues, type.getShape());
57  return success();
58  }
// NOTE(review): the declaration of `attr` (listing line 59) is missing from
// this rendering; presumably a DenseIntElementsAttr, given the
// `getValues<int64_t>()` call below — confirm.
60  if (matchPattern(input, m_Constant(&attr))) {
61  llvm::append_range(shapeValues, attr.getValues<int64_t>());
62  return success();
63  }
64  return failure();
65 }
66 
67 static bool isErrorPropagationPossible(TypeRange operandTypes) {
68  return llvm::any_of(operandTypes,
69  llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
70 }
71 
// Verifies a single-result op whose result is `size` or `index`: if at least
// one operand can hold an error value, the result must be of type `size` so
// that the error can be propagated.
72 static LogicalResult verifySizeOrIndexOp(Operation *op) {
73  assert(op != nullptr && op->getNumResults() == 1);
74  Type resultTy = op->getResultTypes().front();
// NOTE(review): listing line 75, which opens the block closed on line 80, is
// missing from this rendering; presumably
// `if (isErrorPropagationPossible(op->getOperandTypes())) {` — confirm.
76  if (!llvm::isa<SizeType>(resultTy))
77  return op->emitOpError()
78  << "if at least one of the operands can hold error values then "
79  "the result must be of type `size` to propagate them";
80  }
81  return success();
82 }
83 
// Verifies a single-result op whose result is a shape or an extent tensor: if
// at least one operand can hold an error value, the result must be of type
// `shape` so that the error can be propagated.
84 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
85  assert(op != nullptr && op->getNumResults() == 1);
86  Type resultTy = op->getResultTypes().front();
// NOTE(review): listing line 87, which opens the block closed on line 92, is
// missing from this rendering; presumably
// `if (isErrorPropagationPossible(op->getOperandTypes())) {` — confirm.
88  if (!llvm::isa<ShapeType>(resultTy))
89  return op->emitOpError()
90  << "if at least one of the operands can hold error values then "
91  "the result must be of type `shape` to propagate them";
92  }
93  return success();
94 }
95 
96 template <typename... Ty>
97 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
98  return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
99 }
100 
101 template <typename... Ty, typename... ranges>
102 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
103  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // InlinerInterface
108 //===----------------------------------------------------------------------===//
109 
110 namespace {
111 /// This class defines the interface for inlining shape dialect ops.
112 struct ShapeInlinerInterface : public DialectInlinerInterface {
// NOTE(review): listing line 113 is missing from this rendering; presumably
// the inherited constructor
// `using DialectInlinerInterface::DialectInlinerInterface;` — confirm.
114 
115  // Returns true if the given region 'src' can be inlined into the region
116  // 'dest' that is attached to an operation registered to the current dialect.
// Always legal: this hook returns true unconditionally.
117  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
118  IRMapping &) const final {
119  return true;
120  }
121 
122  // Returns true if the given operation 'op', that is registered to this
123  // dialect, can be inlined into the region 'dest' that is attached to an
124  // operation registered to the current dialect.
// Always legal: this hook returns true unconditionally.
125  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
126  IRMapping &) const final {
127  return true;
128  }
129 };
130 } // namespace
131 
// Registers the dialect's operations and types (from the generated .inc lists)
// and the inliner interface. Permits unregistered operations, and calls
// declarePromisedInterfaces for bufferization::BufferizableOpInterface on
// AssumingOp and AssumingYieldOp.
132 void ShapeDialect::initialize() {
133  addOperations<
134 #define GET_OP_LIST
135 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
136  >();
137  addTypes<
138 #define GET_TYPEDEF_LIST
139 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
140  >();
141  addInterfaces<ShapeInlinerInterface>();
142  // Allow unknown operations during prototyping and testing. As the dialect is
143  // still evolving it makes it simple to start with unregistered ops and
144  // try different variants before actually defining the op.
145  allowUnknownOperations();
146  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
147  AssumingYieldOp>();
148 }
149 
// NOTE(review): the first line of this function's signature (listing line
// 150) is missing from this rendering; it declares the `builder` used below.
// Presumably this is ShapeDialect::materializeConstant — confirm.
//
// Creates a constant op of `type` holding `value` at `loc`:
//   - a ub::PoisonAttr value yields a ub::PoisonOp;
//   - ShapeType or extent-tensor types yield a ConstShapeOp;
//   - SizeType yields a ConstSizeOp;
//   - WitnessType yields a ConstWitnessOp;
//   - anything else is delegated to arith::ConstantOp::materialize.
// The llvm::cast calls assert that `value` has the attribute kind each branch
// expects.
151  Attribute value, Type type,
152  Location loc) {
153  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
154  return builder.create<ub::PoisonOp>(loc, type, poison);
155 
156  if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
157  return builder.create<ConstShapeOp>(
158  loc, type, llvm::cast<DenseIntElementsAttr>(value));
159  if (llvm::isa<SizeType>(type))
160  return builder.create<ConstSizeOp>(loc, type,
161  llvm::cast<IntegerAttr>(value));
162  if (llvm::isa<WitnessType>(type))
163  return builder.create<ConstWitnessOp>(loc, type,
164  llvm::cast<BoolAttr>(value));
165 
166  return arith::ConstantOp::materialize(builder, value, type, loc);
167 }
168 
// Verifies the `shape.lib` attribute; any other attribute name is accepted.
// `shape.lib` may only appear on ops with the SymbolTable trait. It must be
// either:
//   - a SymbolRefAttr that resolves to a shape::FunctionLibraryOp, or
//   - an ArrayAttr of SymbolRefAttrs, each resolving to a FunctionLibraryOp,
//     whose op-to-shape mappings never name the same op twice.
169 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
170  NamedAttribute attribute) {
171  // Verify shape.lib attribute.
172  if (attribute.getName() == "shape.lib") {
173  if (!op->hasTrait<OpTrait::SymbolTable>())
174  return op->emitError(
175  "shape.lib attribute may only be on op implementing SymbolTable");
176 
177  if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
178  auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
179  if (!symbol)
180  return op->emitError("shape function library ")
181  << symbolRef << " not found";
182  return isa<shape::FunctionLibraryOp>(symbol)
183  ? success()
184  : op->emitError()
185  << symbolRef << " required to be shape function library";
186  }
187 
188  if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
189  // Verify all entries are function libraries and mappings in libraries
190  // refer to unique ops.
// NOTE(review): the declaration of `key` (listing line 191) is missing from
// this rendering. If it is declared here, before the loop, uniqueness is
// enforced across all listed libraries — confirm.
192  for (auto it : arr) {
193  if (!llvm::isa<SymbolRefAttr>(it))
194  return op->emitError(
195  "only SymbolRefAttr allowed in shape.lib attribute array");
196 
// NOTE(review): unlike the single-reference branch above, a failed lookup is
// not checked here. If lookupSymbolIn returns null, `dyn_cast` on a null
// pointer asserts (LLVM's dyn_cast requires a present value).
// `dyn_cast_if_present` would route a missing symbol to the error below.
197  auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
198  SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
199  if (!shapeFnLib)
200  return op->emitError()
201  << it << " does not refer to FunctionLibraryOp";
202  for (auto mapping : shapeFnLib.getMapping()) {
203  if (!key.insert(mapping.getName()).second) {
204  return op->emitError("only one op to shape mapping allowed, found "
205  "multiple for `")
206  << mapping.getName() << "`";
207  }
208  }
209  }
210  return success();
211  }
212 
213  return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
214  "allowed as shape.lib attribute");
215  }
216  return success();
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // AnyOp
221 //===----------------------------------------------------------------------===//
222 
223 // TODO: Canonicalization should be implemented for shapes that can be
224 // determined through mixtures of the known dimensions of the inputs.
225 OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
226  // Only the last operand is checked because AnyOp is commutative.
227  if (adaptor.getInputs().back())
228  return adaptor.getInputs().back();
229 
230  return nullptr;
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // AssumingOp
235 //===----------------------------------------------------------------------===//
236 
// Parses `<witness> [-> (<result types>)] <region> [attr-dict]`.
// The witness operand is resolved with type WitnessType, and a terminator is
// added to the region if it was elided.
237 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
238  result.regions.reserve(1);
239  Region *doRegion = result.addRegion();
240 
241  auto &builder = parser.getBuilder();
// NOTE(review): the declaration of `cond` (listing line 242) is missing from
// this rendering; presumably an OpAsmParser::UnresolvedOperand — confirm.
243  if (parser.parseOperand(cond) ||
244  parser.resolveOperand(cond, builder.getType<WitnessType>(),
245  result.operands))
246  return failure();
247 
248  // Parse optional results type list.
249  if (parser.parseOptionalArrowTypeList(result.types))
250  return failure();
251 
252  // Parse the region and add a terminator if elided.
253  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
254  return failure();
255  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
256 
257  // Parse the optional attribute list.
258  if (parser.parseOptionalAttrDict(result.attributes))
259  return failure();
260  return success();
261 }
262 
// NOTE(review): the signature line (listing line 263) is missing from this
// rendering; the body uses a printer named `p` — presumably
// AssumingOp::print — confirm.
// Prints the form read by AssumingOp::parse: witness, optional result types,
// the region, then the attribute dictionary. The region's terminator is
// printed only when the op has results.
264  bool yieldsResults = !getResults().empty();
265 
266  p << " " << getWitness();
267  if (yieldsResults)
268  p << " -> (" << getResultTypes() << ")";
269  p << ' ';
270  p.printRegion(getDoRegion(),
271  /*printEntryBlockArgs=*/false,
272  /*printBlockTerminators=*/yieldsResults);
273  p.printOptionalAttrDict((*this)->getAttrs());
274 }
275 
276 namespace {
277 // Removes AssumingOp with a passing witness and inlines the region.
278 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
// NOTE(review): listing line 279 is missing from this rendering; presumably
// the inherited constructor `using OpRewritePattern<AssumingOp>::
// OpRewritePattern;` — confirm against the original source.
280 
281  LogicalResult matchAndRewrite(AssumingOp op,
282  PatternRewriter &rewriter) const override {
283  auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
284  if (!witness || !witness.getPassingAttr())
285  return failure();
286 
287  AssumingOp::inlineRegionIntoParent(op, rewriter);
288  return success();
289  }
290 };
291 
// Drops AssumingOp results that have no uses. The yield is rebuilt with only
// the used values, the body is moved into a new AssumingOp on the same
// witness, and the used results are remapped to the new op's results.
292 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
// NOTE(review): listing line 293 is missing from this rendering; presumably
// the inherited constructor `using OpRewritePattern<AssumingOp>::
// OpRewritePattern;` — confirm against the original source.
294 
295  LogicalResult matchAndRewrite(AssumingOp op,
296  PatternRewriter &rewriter) const override {
297  Block *body = op.getBody();
298  auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
299 
300  // Find used values.
301  SmallVector<Value, 4> newYieldOperands;
302  for (auto [opResult, yieldOperand] :
303  llvm::zip(op.getResults(), yieldOp.getOperands())) {
304  if (!opResult.getUses().empty()) {
305  newYieldOperands.push_back(yieldOperand);
306  }
307  }
308 
309  // Rewrite only if redundant results exist.
310  if (newYieldOperands.size() == yieldOp->getNumOperands())
311  return failure();
312 
313  // Replace yield op in the old assuming op's body and move the entire region
314  // to the new assuming op.
315  rewriter.setInsertionPointToEnd(body);
316  auto newYieldOp =
317  rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
318  rewriter.setInsertionPoint(op);
319  auto newOp = rewriter.create<AssumingOp>(
320  op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
// NOTE(review): `takeBody` moves the region directly rather than through
// `rewriter`, so rewrite listeners are not notified of the move;
// `rewriter.inlineRegionBefore` may be preferable — confirm.
321  newOp.getDoRegion().takeBody(op.getDoRegion());
322 
323  // Use the new results to replace the previously used ones.
// Unused results map to nullptr; used ones take the next new result in order.
324  SmallVector<Value, 4> replacementValues;
325  auto src = newOp.getResults().begin();
326  for (auto it : op.getResults()) {
327  if (it.getUses().empty())
328  replacementValues.push_back(nullptr);
329  else
330  replacementValues.push_back(*src++);
331  }
332  rewriter.replaceOp(op, replacementValues);
333  return success();
334  }
335 };
336 } // namespace
337 
338 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
339  MLIRContext *context) {
340  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
341 }
342 
343 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
344 void AssumingOp::getSuccessorRegions(
// NOTE(review): the parameter list (listing line 345) is missing from this
// rendering; the body uses `point` and `regions` — confirm.
346  // AssumingOp has unconditional control flow into the region and back to the
347  // parent, so the successor depends only on where control comes from: from
348  // the parent op it enters the region; from the region it returns to the op.
349  if (!point.isParent()) {
350  regions.push_back(RegionSuccessor(getResults()));
351  return;
352  }
353 
354  regions.push_back(RegionSuccessor(&getDoRegion()));
355 }
356 
// Splices the body of `op` into the enclosing block at the rewriter's current
// insertion point, replaces `op`'s results with the values its terminator
// yields, and erases that terminator.
// NOTE(review): this splits the block at `rewriter.getInsertionPoint()` and
// never positions the rewriter itself, so it appears to assume the insertion
// point is immediately before `op` (AssumingWithTrue calls it without setting
// one) — confirm.
357 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
358  PatternRewriter &rewriter) {
359  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
360  auto *assumingBlock = op.getBody();
361  auto initPosition = rewriter.getInsertionPoint();
362  auto *blockAfterAssuming =
363  rewriter.splitBlock(blockBeforeAssuming, initPosition);
364 
365  // Remove the AssumingOp and AssumingYieldOp.
366  auto &yieldOp = assumingBlock->back();
367  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
368  rewriter.replaceOp(op, yieldOp.getOperands());
369  rewriter.eraseOp(&yieldOp);
370 
371  // Merge blocks together as there was no branching behavior from the
372  // AssumingOp.
373  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
374  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
375 }
376 
// Builds an AssumingOp on `witness` whose single-block body is populated by
// `bodyBuilder`. The values it returns are yielded by an AssumingYieldOp and
// their types become the op's result types. The builder's insertion point is
// restored on exit by the InsertionGuard.
377 void AssumingOp::build(
378  OpBuilder &builder, OperationState &result, Value witness,
// NOTE(review): listing line 379 is missing from this rendering; it declares
// the `bodyBuilder` callback used below — confirm its exact type.
380  OpBuilder::InsertionGuard g(builder);
381 
382  result.addOperands(witness);
383  Region *bodyRegion = result.addRegion();
384  builder.createBlock(bodyRegion);
385 
386  // Build body.
387  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
388  builder.create<AssumingYieldOp>(result.location, yieldValues);
389 
390  SmallVector<Type, 2> assumingTypes;
391  for (Value v : yieldValues)
392  assumingTypes.push_back(v.getType());
393  result.addTypes(assumingTypes);
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // AddOp
398 //===----------------------------------------------------------------------===//
399 
400 LogicalResult mlir::shape::AddOp::inferReturnTypes(
401  MLIRContext *context, std::optional<Location> location,
402  AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
403  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
404  llvm::isa<SizeType>(adaptor.getRhs().getType()))
405  inferredReturnTypes.assign({SizeType::get(context)});
406  else
407  inferredReturnTypes.assign({IndexType::get(context)});
408  return success();
409 }
410 
411 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
412  // SizeType is compatible with IndexType.
413  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
414 }
415 
416 OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
417  // add(x, 0) -> x
418  if (matchPattern(getRhs(), m_Zero()))
419  return getLhs();
420 
421  return constFoldBinaryOp<IntegerAttr>(
422  adaptor.getOperands(),
423  [](APInt a, const APInt &b) { return std::move(a) + b; });
424 }
425 
426 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
427 
428 //===----------------------------------------------------------------------===//
429 // AssumingAllOp
430 //===----------------------------------------------------------------------===//
431 
432 namespace {
433 
434 // Merge multiple `shape.assuming_all` operations together.
435 //
436 // %0 = shape.assuming_all %w0, %w1
437 // %1 = shape.assuming_all %w2, %0
438 //
439 // to:
440 //
441 // %0 = shape.assuming_all %w0, %w1, %w2
442 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
// NOTE(review): listing line 443 is missing from this rendering; presumably
// the inherited constructor `using OpRewritePattern<AssumingAllOp>::
// OpRewritePattern;` — confirm against the original source.
444 
445  LogicalResult matchAndRewrite(AssumingAllOp op,
446  PatternRewriter &rewriter) const override {
447  SmallVector<Value> operands;
448 
// Flatten one level: operands produced by another AssumingAllOp are replaced
// by that op's own operands.
449  for (Value operand : op.getInputs()) {
450  if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
451  operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
452  else
453  operands.push_back(operand);
454  }
455 
456  // We didn't find any other `assuming_all` ops to merge with.
// NOTE(review): this compares counts only, so a nested AssumingAllOp with
// exactly one operand leaves the count unchanged and is not merged here.
457  if (operands.size() == op.getNumOperands())
458  return failure();
459 
460  // Replace with a new `assuming_all` operation with merged constraints.
461  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
462  return success();
463  }
464 };
465 
466 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
467 // are subsumed by others.
468 //
469 // %0 = shape.cstr_broadcastable %shape0, %shape1
470 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
471 //
472 // %2 = shape.cstr_broadcastable %shape3, %shape4
473 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
474 //
475 // %4 = shape.assuming_all %0, %1, %2, %3
476 //
477 // to:
478 //
479 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
480 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
481 // %2 = shape.assuming_all %0, %1
482 //
483 // In this example if shapes [0, 1, 2] are broadcastable, then it means that
484 // shapes [0, 1] are broadcastable too, and can be removed from the list of
485 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
486 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
487 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
// NOTE(review): listing line 488 is missing from this rendering; presumably
// the inherited constructor `using OpRewritePattern<AssumingAllOp>::
// OpRewritePattern;` — confirm against the original source.
489 
490  LogicalResult matchAndRewrite(AssumingAllOp op,
491  PatternRewriter &rewriter) const override {
492  // Collect all `CstrBroadcastableOp` operands first.
// NOTE(review): the declaration of `operands` (listing line 493) is missing;
// its `insert`/`size` use suggests a set-like container of
// CstrBroadcastableOp — confirm.
494  for (Value operand : op.getInputs()) {
495  // TODO: Apply this optimization if some of the witnesses are not
496  // produced by the `cstr_broadcastable`.
497  auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
498  if (!broadcastable)
499  return failure();
500 
501  operands.insert(broadcastable);
502  }
503 
504  // Skip trivial `assuming_all` operations.
505  if (operands.size() <= 1)
506  return failure();
507 
508  // Collect shapes checked by `cstr_broadcastable` operands.
// NOTE(review): the declaration of `shapes` (listing line 509) is missing;
// from its uses it holds (CstrBroadcastableOp, DenseSet<Value>) pairs with
// pointer iterators — confirm.
510  for (auto cstr : operands) {
511  DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
512  shapes.emplace_back(cstr, std::move(shapesSet));
513  }
514 
515  // Sort by the number of shape operands (larger to smaller).
// NOTE(review): the comparator (and `isSubset` below) take pairs by value,
// copying a DenseSet on every call; `const auto &` would avoid that.
516  llvm::sort(shapes, [](auto a, auto b) {
517  return a.first.getNumOperands() > b.first.getNumOperands();
518  });
519 
520  // We start from the `cstr_broadcastable` operations with largest number of
521  // shape operands, and remove redundant `cstr_broadcastable` operations. We
522  // do this until we find a set of `cstr_broadcastable` operations with
523  // non-overlapping constraints.
524  SmallVector<CstrBroadcastableOp> markedForErase;
525 
526  for (unsigned i = 0; i < shapes.size(); ++i) {
527  auto isSubset = [&](auto pair) {
528  return llvm::set_is_subset(pair.second, shapes[i].second);
529  };
530 
531  // Keep redundant `cstr_broadcastable` operations to be erased.
// NOTE(review): after std::remove_if, the elements in [it, end) are in a
// valid but unspecified (moved-from) state; they are NOT guaranteed to be the
// removed entries. Op handles are presumably copied on move, so the tail can
// hold copies of *kept* ops while genuinely removed ops are lost. An op kept
// in one iteration and removed in a later one can then be pushed twice, and
// the erase loop below would call `use_empty()` on an already-erased op.
// Using std::stable_partition (kept entries first) would make [it, end)
// exactly the removed entries.
532  auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
533  for (auto *it0 = it; it0 < shapes.end(); ++it0)
534  markedForErase.push_back(it0->first);
535  shapes.erase(it, shapes.end());
536  }
537 
538  // We didn't find any operands that could be removed.
539  if (markedForErase.empty())
540  return failure();
541 
542  // Collect non-overlapping `cstr_broadcastable` constraints.
543  SmallVector<Value> uniqueConstraints;
544  for (auto &shape : shapes)
545  uniqueConstraints.push_back(shape.first.getResult());
546 
547  // Replace with a new `assuming_all` operation ...
548  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
549 
550  // ... and maybe erase `cstr_broadcastable` ops without uses.
// NOTE(review): the loop variable shadows the `op` parameter.
551  for (auto &op : markedForErase)
552  if (op->use_empty())
553  rewriter.eraseOp(op);
554 
555  return success();
556  }
557 };
558 
// Rewrites an `assuming_all` whose witnesses all come from CstrEqOp into a
// single CstrEqOp over all their shapes. It bails out when a CstrEqOp shares
// no shape with those collected so far, so each merged equality is linked to
// the previous ones through a common shape.
559 struct AssumingAllToCstrEqCanonicalization
560  : public OpRewritePattern<AssumingAllOp> {
// NOTE(review): listing line 561 is missing from this rendering; presumably
// the inherited constructor `using OpRewritePattern<AssumingAllOp>::
// OpRewritePattern;` — confirm against the original source.
562 
563  LogicalResult matchAndRewrite(AssumingAllOp op,
564  PatternRewriter &rewriter) const override {
565  SmallVector<Value, 8> shapes;
566  for (Value w : op.getInputs()) {
567  auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
568  if (!cstrEqOp)
569  return failure();
570  bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
571  return llvm::is_contained(shapes, s);
572  });
573  if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
574  return failure();
575  shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
576  }
577  rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
578  return success();
579  }
580 };
581 
// Rebuilds `OpTy` with duplicate operands removed, keeping the first
// occurrence of each (SetVector preserves insertion order) along with the
// original result types and attributes.
582 template <typename OpTy>
583 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
// NOTE(review): listing line 584 is missing from this rendering; presumably
// the inherited constructor `using OpRewritePattern<OpTy>::OpRewritePattern;`
// — confirm against the original source.
585 
586  LogicalResult matchAndRewrite(OpTy op,
587  PatternRewriter &rewriter) const override {
588  // Find unique operands.
589  SetVector<Value> unique(op.operand_begin(), op.operand_end());
590 
591  // Reduce op to equivalent with unique operands.
592  if (unique.size() < op.getNumOperands()) {
593  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
594  unique.takeVector(), op->getAttrs());
595  return success();
596  }
597 
598  return failure();
599  }
600 };
601 } // namespace
602 
603 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
604  MLIRContext *context) {
605  patterns
606  .add<MergeAssumingAllOps, AssumingAllOneOp,
607  AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
608  RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
609 }
610 
611 OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
612  // Iterate in reverse to first handle all constant operands. They are
613  // guaranteed to be the tail of the inputs because this is commutative.
614  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
615  Attribute a = adaptor.getInputs()[idx];
616  // Cannot fold if any inputs are not constant;
617  if (!a)
618  return nullptr;
619 
620  // We do not need to keep statically known values after handling them in
621  // this method.
622  getOperation()->eraseOperand(idx);
623 
624  // Always false if any input is statically known false
625  if (!llvm::cast<BoolAttr>(a).getValue())
626  return a;
627  }
628  // If this is reached, all inputs were statically known passing.
629  return BoolAttr::get(getContext(), true);
630 }
631 
632 LogicalResult AssumingAllOp::verify() {
633  // Ensure that AssumingAllOp contains at least one operand
634  if (getNumOperands() == 0)
635  return emitOpError("no operands specified");
636 
637  return success();
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // BroadcastOp
642 //===----------------------------------------------------------------------===//
643 
644 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
645  if (getShapes().size() == 1) {
646  // Otherwise, we need a cast which would be a canonicalization, not folding.
647  if (getShapes().front().getType() != getType())
648  return nullptr;
649  return getShapes().front();
650  }
651 
652  // TODO: Support folding with more than 2 input shapes
653  if (getShapes().size() > 2)
654  return nullptr;
655 
656  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
657  return nullptr;
658  auto lhsShape = llvm::to_vector<6>(
659  llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
660  .getValues<int64_t>());
661  auto rhsShape = llvm::to_vector<6>(
662  llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
663  .getValues<int64_t>());
664  SmallVector<int64_t, 6> resultShape;
665 
666  // If the shapes are not compatible, we can't fold it.
667  // TODO: Fold to an "error".
668  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
669  return nullptr;
670 
671  Builder builder(getContext());
672  return builder.getIndexTensorAttr(resultShape);
673 }
674 
675 LogicalResult BroadcastOp::verify() {
676  return verifyShapeOrExtentTensorOp(*this);
677 }
678 
679 namespace {
// Drops operands that are statically known to be empty shapes: values whose
// ranked tensor type has a static first dimension of 0, or that are produced
// by a ConstShapeOp with no extents. If every operand is dropped, the op is
// replaced by an empty ConstShapeOp of the op's first result type.
// NOTE(review): this pattern is also registered for CstrBroadcastableOp in
// this file. For that op the all-empty branch would build a ConstShapeOp with
// the op's result type; confirm that case is unreachable (e.g. folded away
// first).
680 template <typename OpTy>
681 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
// NOTE(review): listing line 682 is missing from this rendering; presumably
// the inherited constructor `using OpRewritePattern<OpTy>::OpRewritePattern;`
// — confirm against the original source.
683 
684  LogicalResult matchAndRewrite(OpTy op,
685  PatternRewriter &rewriter) const override {
686  auto isPotentiallyNonEmptyShape = [](Value shape) {
687  if (auto extentTensorTy =
688  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
689  if (extentTensorTy.getDimSize(0) == 0)
690  return false;
691  }
692  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
693  if (constShape.getShape().empty())
694  return false;
695  }
696  return true;
697  };
698  auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
699  isPotentiallyNonEmptyShape);
700 
701  // Replace the op with empty shape constant if all operands are reduced to
702  // be empty.
703  if (newOperands.empty()) {
704  rewriter.replaceOpWithNewOp<ConstShapeOp>(
705  op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
706  return success();
707  }
708 
709  // Reduce op to equivalent without empty shape operands.
710  if (newOperands.size() < op.getNumOperands()) {
711  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
712  op->getAttrs());
713  return success();
714  }
715 
716  return failure();
717  }
718 };
719 
// Replaces a single-operand BroadcastOp with that operand. When the types
// differ, it inserts a FromExtentTensorOp (for a `shape` result) or a
// tensor::CastOp (between extent tensors).
720 struct BroadcastForwardSingleOperandPattern
721  : public OpRewritePattern<BroadcastOp> {
// NOTE(review): listing line 722 is missing from this rendering; presumably
// the inherited constructor `using OpRewritePattern<BroadcastOp>::
// OpRewritePattern;` — confirm against the original source.
723 
724  LogicalResult matchAndRewrite(BroadcastOp op,
725  PatternRewriter &rewriter) const override {
726  if (op.getNumOperands() != 1)
727  return failure();
728  Value replacement = op.getShapes().front();
729 
730  // Insert cast if needed.
731  if (replacement.getType() != op.getType()) {
732  auto loc = op.getLoc();
733  if (llvm::isa<ShapeType>(op.getType())) {
734  replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
735  } else {
// The first conjunct always holds in this branch; the second relies on the
// verifier requiring a `shape` result whenever an operand is a `shape`.
736  assert(!llvm::isa<ShapeType>(op.getType()) &&
737  !llvm::isa<ShapeType>(replacement.getType()) &&
738  "expect extent tensor cast");
739  replacement =
740  rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
741  }
742  }
743 
744  rewriter.replaceOp(op, replacement);
745  return success();
746  }
747 };
748 
// Folds constant (ConstShapeOp) operands of a BroadcastOp into one constant
// operand that is appended after the remaining operands. It requires at least
// two constant operands to be absorbed.
749 struct BroadcastFoldConstantOperandsPattern
750  : public OpRewritePattern<BroadcastOp> {
// NOTE(review): listing line 751 is missing from this rendering; presumably
// the inherited constructor `using OpRewritePattern<BroadcastOp>::
// OpRewritePattern;` — confirm against the original source.
752 
753  LogicalResult matchAndRewrite(BroadcastOp op,
754  PatternRewriter &rewriter) const override {
755  SmallVector<int64_t, 8> foldedConstantShape;
756  SmallVector<Value, 8> newShapeOperands;
757  for (Value shape : op.getShapes()) {
758  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
759  SmallVector<int64_t, 8> newFoldedConstantShape;
// NOTE(review): listing line 760, which opens the call whose arguments
// follow, is missing from this rendering; presumably
// `if (OpTrait::util::getBroadcastedShape(` — confirm. Constants that fail it
// fall through and stay as regular operands.
761  foldedConstantShape,
762  llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
763  newFoldedConstantShape)) {
764  foldedConstantShape = newFoldedConstantShape;
765  continue;
766  }
767  }
768  newShapeOperands.push_back(shape);
769  }
770 
771  // Need at least two constant operands to fold anything.
772  if (op.getNumOperands() - newShapeOperands.size() < 2)
773  return failure();
774 
775  auto foldedConstantOperandsTy = RankedTensorType::get(
776  {static_cast<int64_t>(foldedConstantShape.size())},
777  rewriter.getIndexType());
778  newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
779  op.getLoc(), foldedConstantOperandsTy,
780  rewriter.getIndexTensorAttr(foldedConstantShape)));
781  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
782  newShapeOperands);
783  return success();
784  }
785 };
786 
// Bypasses tensor::CastOp operands whose result is a dynamically sized extent
// tensor, feeding the cast's source to the op directly.
787 template <typename OpTy>
788 struct CanonicalizeCastExtentTensorOperandsPattern
789  : public OpRewritePattern<OpTy> {
// NOTE(review): listing line 790 is missing from this rendering; presumably
// the inherited constructor `using OpRewritePattern<OpTy>::OpRewritePattern;`
// — confirm against the original source.
791 
792  LogicalResult matchAndRewrite(OpTy op,
793  PatternRewriter &rewriter) const override {
794  // Canonicalize operands.
795  bool anyChange = false;
796  auto canonicalizeOperand = [&](Value operand) -> Value {
797  if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
798  // Only eliminate the cast if it holds no shape information.
799  bool isInformationLoosingCast =
800  llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
801  if (isInformationLoosingCast) {
802  anyChange = true;
803  return castOp.getSource();
804  }
805  }
806  return operand;
807  };
808  auto newOperands = llvm::to_vector<8>(
809  llvm::map_range(op.getOperands(), canonicalizeOperand));
810 
811  // Rewrite op if any change required.
812  if (!anyChange)
813  return failure();
814  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
815  return success();
816  }
817 };
818 
// Applies when a BroadcastOp's result is a dynamically sized extent tensor and
// every extent-tensor operand has a static size. The op is rebuilt with result
// type `tensor<maxRank x index>` and cast back to the original result type.
819 struct BroadcastConcretizeResultTypePattern
820  : public OpRewritePattern<BroadcastOp> {
// NOTE(review): listing line 821 is missing from this rendering; presumably
// the inherited constructor `using OpRewritePattern<BroadcastOp>::
// OpRewritePattern;` — confirm against the original source.
822 
823  LogicalResult matchAndRewrite(BroadcastOp op,
824  PatternRewriter &rewriter) const override {
825  // Only concretize dynamic extent tensor result types.
826  auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
827  if (!resultTy || !resultTy.isDynamicDim(0))
828  return failure();
829 
830  // Infer resulting shape rank if possible.
831  int64_t maxRank = 0;
832  for (Value shape : op.getShapes()) {
833  if (auto extentTensorTy =
834  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
835  // Cannot infer resulting shape rank if any operand is dynamically
836  // ranked.
837  if (extentTensorTy.isDynamicDim(0))
838  return failure();
839  maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
840  }
841  }
842 
843  auto newOp = rewriter.create<BroadcastOp>(
844  op.getLoc(), getExtentTensorType(getContext(), maxRank),
845  op.getShapes());
846  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
847  return success();
848  }
849 };
850 } // namespace
851 
852 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
853  MLIRContext *context) {
854  patterns.add<BroadcastConcretizeResultTypePattern,
855  BroadcastFoldConstantOperandsPattern,
856  BroadcastForwardSingleOperandPattern,
857  CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
858  RemoveDuplicateOperandsPattern<BroadcastOp>,
859  RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
860 }
861 
862 //===----------------------------------------------------------------------===//
863 // ConcatOp
864 //===----------------------------------------------------------------------===//
865 
866 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
867  if (!adaptor.getLhs() || !adaptor.getRhs())
868  return nullptr;
869  auto lhsShape = llvm::to_vector<6>(
870  llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
871  auto rhsShape = llvm::to_vector<6>(
872  llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
873  SmallVector<int64_t, 6> resultShape;
874  resultShape.append(lhsShape.begin(), lhsShape.end());
875  resultShape.append(rhsShape.begin(), rhsShape.end());
876  Builder builder(getContext());
877  return builder.getIndexTensorAttr(resultShape);
878 }
879 
880 //===----------------------------------------------------------------------===//
881 // ConstShapeOp
882 //===----------------------------------------------------------------------===//
883 
// NOTE(review): the signature line (listing line 884) is missing from this
// rendering; the body uses a printer named `p` — presumably
// ConstShapeOp::print — confirm.
// Prints `[attr-dict without "shape"] [e0, e1, ...] : <type>`, the form read
// by ConstShapeOp::parse.
885  p << " ";
886  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
887  p << "[";
888  interleaveComma(getShape().getValues<int64_t>(), p);
889  p << "] : ";
890  p.printType(getType());
891 }
892 
// Parses `[attr-dict] [e0, e1, ...] : <type>`. The extent list is read as an
// ArrayAttr of IntegerAttrs and stored as the `shape` index tensor attribute.
// NOTE(review): the non-array and non-integer cases return failure() without
// emitting a diagnostic; `parser.emitError` would give users a message.
893 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
894  if (parser.parseOptionalAttrDict(result.attributes))
895  return failure();
896  // We piggy-back on ArrayAttr parsing, though we don't internally store the
897  // shape as an ArrayAttr.
898  // TODO: Implement custom parser and maybe make syntax a bit more concise.
899  Attribute extentsRaw;
900  NamedAttrList dummy;
901  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
902  return failure();
903  auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
904  if (!extentsArray)
905  return failure();
// NOTE(review): the declaration of `ints` (listing line 906) is missing from
// this rendering; presumably a SmallVector<int64_t> — confirm.
907  for (Attribute extent : extentsArray) {
908  IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
909  if (!attr)
910  return failure();
911  ints.push_back(attr.getInt());
912  }
913  Builder &builder = parser.getBuilder();
914  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
915  Type resultTy;
916  if (parser.parseColonType(resultTy))
917  return failure();
918  result.types.push_back(resultTy);
919  return success();
920 }
921 
922 OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
923 
924 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
925  MLIRContext *context) {
926  patterns.add<TensorCastConstShape>(context);
927 }
928 
929 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
930  MLIRContext *context, std::optional<Location> location,
931  ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
932  Builder b(context);
933  const Properties prop = adaptor.getProperties();
934  inferredReturnTypes.assign({RankedTensorType::get(
935  {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
936  return success();
937 }
938 
939 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
940  TypeRange r) {
941  if (l.size() != 1 || r.size() != 1)
942  return false;
943 
944  Type lhs = l.front();
945  Type rhs = r.front();
946 
947  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
948  // Shape type is compatible with all other valid return types.
949  return true;
950  return lhs == rhs;
951 }
952 
953 //===----------------------------------------------------------------------===//
954 // CstrBroadcastableOp
955 //===----------------------------------------------------------------------===//
956 
957 void CstrBroadcastableOp::getCanonicalizationPatterns(
959  // Canonicalization patterns have overlap with the considerations during
960  // folding in case additional shape information is inferred at some point that
961  // does not result in folding.
962  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
963  CstrBroadcastableEqOps,
964  RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
965  RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
966 }
967 
968 // Return true if there is exactly one attribute not representing a scalar
969 // broadcast.
971  bool nonScalarSeen = false;
972  for (Attribute a : attributes) {
973  if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
974  if (nonScalarSeen)
975  return false;
976  nonScalarSeen = true;
977  }
978  }
979  return true;
980 }
981 
982 OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
983  // No broadcasting is needed if all operands but one are scalar.
984  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
985  return BoolAttr::get(getContext(), true);
986 
987  if ([&] {
989  for (const auto &operand : adaptor.getShapes()) {
990  if (!operand)
991  return false;
992  extents.push_back(llvm::to_vector<6>(
993  llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
994  }
996  }())
997  return BoolAttr::get(getContext(), true);
998 
999  // Lastly, see if folding can be completed based on what constraints are known
1000  // on the input shapes.
1001  if ([&] {
1003  for (auto shapeValue : getShapes()) {
1004  extents.emplace_back();
1005  if (failed(getShapeVec(shapeValue, extents.back())))
1006  return false;
1007  }
1009  }())
1010  return BoolAttr::get(getContext(), true);
1011 
1012  // Because a failing witness result here represents an eventual assertion
1013  // failure, we do not replace it with a constant witness.
1014  return nullptr;
1015 }
1016 
1017 LogicalResult CstrBroadcastableOp::verify() {
1018  // Ensure that CstrBroadcastableOp contains at least two operands
1019  if (getNumOperands() < 2)
1020  return emitOpError("required at least 2 input shapes");
1021  return success();
1022 }
1023 
1024 //===----------------------------------------------------------------------===//
1025 // CstrEqOp
1026 //===----------------------------------------------------------------------===//
1027 
1028 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1029  MLIRContext *context) {
1030  // If inputs are equal, return passing witness
1031  patterns.add<CstrEqEqOps>(context);
1032 }
1033 
1034 OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
1035  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
1036  return a && a == adaptor.getShapes().front();
1037  }))
1038  return BoolAttr::get(getContext(), true);
1039 
1040  // Because a failing witness result here represents an eventual assertion
1041  // failure, we do not try to replace it with a constant witness. Similarly, we
1042  // cannot if there are any non-const inputs.
1043  return nullptr;
1044 }
1045 
1046 //===----------------------------------------------------------------------===//
1047 // ConstSizeOp
1048 //===----------------------------------------------------------------------===//
1049 
1050 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
1051  int64_t value) {
1052  build(builder, result, builder.getIndexAttr(value));
1053 }
1054 
1055 OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
1056 
1057 void ConstSizeOp::getAsmResultNames(
1058  llvm::function_ref<void(Value, StringRef)> setNameFn) {
1059  SmallString<4> buffer;
1060  llvm::raw_svector_ostream os(buffer);
1061  os << "c" << getValue();
1062  setNameFn(getResult(), os.str());
1063 }
1064 
1065 //===----------------------------------------------------------------------===//
1066 // ConstWitnessOp
1067 //===----------------------------------------------------------------------===//
1068 
1069 OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
1070 
1071 //===----------------------------------------------------------------------===//
1072 // CstrRequireOp
1073 //===----------------------------------------------------------------------===//
1074 
1075 OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
1076  return adaptor.getPred();
1077 }
1078 
1079 //===----------------------------------------------------------------------===//
1080 // DimOp
1081 //===----------------------------------------------------------------------===//
1082 
1083 std::optional<int64_t> DimOp::getConstantIndex() {
1084  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1085  return constSizeOp.getValue().getLimitedValue();
1086  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1087  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1088  return std::nullopt;
1089 }
1090 
1091 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1092  Type valType = getValue().getType();
1093  auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1094  if (!valShapedType || !valShapedType.hasRank())
1095  return nullptr;
1096  std::optional<int64_t> index = getConstantIndex();
1097  if (!index.has_value())
1098  return nullptr;
1099  if (index.value() < 0 || index.value() >= valShapedType.getRank())
1100  return nullptr;
1101  auto extent = valShapedType.getDimSize(*index);
1102  if (ShapedType::isDynamic(extent))
1103  return nullptr;
1104  return IntegerAttr::get(IndexType::get(getContext()), extent);
1105 }
1106 
1107 LogicalResult mlir::shape::DimOp::inferReturnTypes(
1108  MLIRContext *context, std::optional<Location> location,
1109  DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1110  inferredReturnTypes.assign({adaptor.getIndex().getType()});
1111  return success();
1112 }
1113 
1114 bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1115  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1116 }
1117 
1118 //===----------------------------------------------------------------------===//
1119 // DivOp
1120 //===----------------------------------------------------------------------===//
1121 
1122 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
1123  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1124  if (!lhs)
1125  return nullptr;
1126  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1127  if (!rhs)
1128  return nullptr;
1129 
1130  // Division in APInt does not follow floor(lhs, rhs) when the result is
1131  // negative. Rather, APInt rounds toward zero.
1132  APInt quotient, remainder;
1133  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1134  if (quotient.isNegative() && !remainder.isZero()) {
1135  quotient -= 1;
1136  }
1137 
1138  Type indexTy = IndexType::get(getContext());
1139  return IntegerAttr::get(indexTy, quotient);
1140 }
1141 
1142 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1143  MLIRContext *context, std::optional<Location> location,
1144  DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1145  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1146  llvm::isa<SizeType>(adaptor.getRhs().getType()))
1147  inferredReturnTypes.assign({SizeType::get(context)});
1148  else
1149  inferredReturnTypes.assign({IndexType::get(context)});
1150  return success();
1151 }
1152 
1153 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1154  // SizeType is compatible with IndexType.
1155  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1156 }
1157 
1158 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1159 
1160 //===----------------------------------------------------------------------===//
1161 // ShapeEqOp
1162 //===----------------------------------------------------------------------===//
1163 
1164 OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
1165  bool allSame = true;
1166  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1167  return {};
1168  for (Attribute operand : adaptor.getShapes().drop_front()) {
1169  if (!operand)
1170  return {};
1171  allSame = allSame && operand == adaptor.getShapes().front();
1172  }
1173  return BoolAttr::get(getContext(), allSame);
1174 }
1175 
1176 //===----------------------------------------------------------------------===//
1177 // IndexToSizeOp
1178 //===----------------------------------------------------------------------===//
1179 
1180 OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1181  // Constant values of both types, `shape.size` and `index`, are represented as
1182  // `IntegerAttr`s which makes constant folding simple.
1183  if (Attribute arg = adaptor.getArg())
1184  return arg;
1185  return {};
1186 }
1187 
1188 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1189  MLIRContext *context) {
1190  patterns.add<SizeToIndexToSizeCanonicalization>(context);
1191 }
1192 
1193 //===----------------------------------------------------------------------===//
1194 // FromExtentsOp
1195 //===----------------------------------------------------------------------===//
1196 
1197 OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1198  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
1199  return nullptr;
1200  SmallVector<int64_t, 6> extents;
1201  for (auto attr : adaptor.getExtents())
1202  extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1203  Builder builder(getContext());
1204  return builder.getIndexTensorAttr(extents);
1205 }
1206 
1207 //===----------------------------------------------------------------------===//
1208 // FunctionLibraryOp
1209 //===----------------------------------------------------------------------===//
1210 
1211 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1212  StringRef name) {
1213  result.attributes.push_back(builder.getNamedAttr(
1215 }
1216 
1217 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1218  auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1219  getMapping().get(op->getName().getIdentifier()));
1220  if (!attr)
1221  return nullptr;
1222  return lookupSymbol<FuncOp>(attr);
1223 }
1224 
1225 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1226  OperationState &result) {
1227  // Parse the op name.
1228  StringAttr nameAttr;
1229  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1230  result.attributes))
1231  return failure();
1232 
1233  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1234  return failure();
1235 
1236  auto *bodyRegion = result.addRegion();
1237  if (parser.parseRegion(*bodyRegion))
1238  return failure();
1239 
1240  if (parser.parseKeyword("mapping"))
1241  return failure();
1242 
1243  DictionaryAttr mappingAttr;
1244  if (parser.parseAttribute(mappingAttr,
1245  parser.getBuilder().getType<NoneType>(), "mapping",
1246  result.attributes))
1247  return failure();
1248  return success();
1249 }
1250 
1252  p << ' ';
1253  p.printSymbolName(getName());
1255  (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1256  p << ' ';
1257  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1258  /*printBlockTerminators=*/false);
1259  p << " mapping ";
1260  p.printAttributeWithoutType(getMappingAttr());
1261 }
1262 
1263 //===----------------------------------------------------------------------===//
1264 // FuncOp
1265 //===----------------------------------------------------------------------===//
1266 
1267 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1268  ArrayRef<NamedAttribute> attrs) {
1269  OpBuilder builder(location->getContext());
1270  OperationState state(location, getOperationName());
1271  FuncOp::build(builder, state, name, type, attrs);
1272  return cast<FuncOp>(Operation::create(state));
1273 }
1274 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1276  SmallVector<NamedAttribute, 8> attrRef(attrs);
1277  return create(location, name, type, llvm::ArrayRef(attrRef));
1278 }
1279 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1281  ArrayRef<DictionaryAttr> argAttrs) {
1282  FuncOp func = create(location, name, type, attrs);
1283  func.setAllArgAttrs(argAttrs);
1284  return func;
1285 }
1286 
1287 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
1288  FunctionType type, ArrayRef<NamedAttribute> attrs,
1289  ArrayRef<DictionaryAttr> argAttrs) {
1290  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
1291  builder.getStringAttr(name));
1292  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
1293  TypeAttr::get(type));
1294  state.attributes.append(attrs.begin(), attrs.end());
1295  state.addRegion();
1296 
1297  if (argAttrs.empty())
1298  return;
1299  assert(type.getNumInputs() == argAttrs.size());
1301  builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
1302  getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
1303 }
1304 
1305 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1306  auto buildFuncType =
1307  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
1309  std::string &) { return builder.getFunctionType(argTypes, results); };
1310 
1312  parser, result, /*allowVariadic=*/false,
1313  getFunctionTypeAttrName(result.name), buildFuncType,
1314  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1315 }
1316 
1317 void FuncOp::print(OpAsmPrinter &p) {
1319  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
1320  getArgAttrsAttrName(), getResAttrsAttrName());
1321 }
1322 
1323 //===----------------------------------------------------------------------===//
1324 // GetExtentOp
1325 //===----------------------------------------------------------------------===//
1326 
1327 std::optional<int64_t> GetExtentOp::getConstantDim() {
1328  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1329  return constSizeOp.getValue().getLimitedValue();
1330  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1331  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1332  return std::nullopt;
1333 }
1334 
1335 OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
1336  auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1337  if (!elements)
1338  return nullptr;
1339  std::optional<int64_t> dim = getConstantDim();
1340  if (!dim.has_value())
1341  return nullptr;
1342  if (dim.value() >= elements.getNumElements())
1343  return nullptr;
1344  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1345 }
1346 
1347 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1348  int64_t dim) {
1349  auto loc = result.location;
1350  auto dimAttr = builder.getIndexAttr(dim);
1351  if (llvm::isa<ShapeType>(shape.getType())) {
1352  Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1353  build(builder, result, builder.getType<SizeType>(), shape, dim);
1354  } else {
1355  Value dim =
1356  builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1357  build(builder, result, builder.getIndexType(), shape, dim);
1358  }
1359 }
1360 
1361 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1362  MLIRContext *context, std::optional<Location> location,
1363  GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1364  inferredReturnTypes.assign({IndexType::get(context)});
1365  return success();
1366 }
1367 
1368 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1369  TypeRange r) {
1370  // SizeType is compatible with IndexType.
1371  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1372 }
1373 
1374 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1375 
1376 //===----------------------------------------------------------------------===//
1377 // IsBroadcastableOp
1378 //===----------------------------------------------------------------------===//
1379 
1380 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1381  MLIRContext *context) {
1382  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1383 }
1384 
1385 OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1386  // Can always broadcast fewer than two shapes.
1387  if (adaptor.getShapes().size() < 2) {
1388  return BoolAttr::get(getContext(), true);
1389  }
1390 
1391  return nullptr;
1392 }
1393 
1394 //===----------------------------------------------------------------------===//
1395 // MeetOp
1396 //===----------------------------------------------------------------------===//
1397 
1398 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1399  MLIRContext *context, std::optional<Location> location,
1400  MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1401  if (adaptor.getOperands().empty())
1402  return failure();
1403 
1404  auto isShapeType = [](Type arg) {
1405  if (llvm::isa<ShapeType>(arg))
1406  return true;
1407  return isExtentTensorType(arg);
1408  };
1409 
1410  ValueRange::type_range types = adaptor.getOperands().getTypes();
1411  Type acc = types.front();
1412  for (auto t : drop_begin(types)) {
1413  Type l = acc, r = t;
1414  if (!llvm::isa<ShapeType, SizeType>(l))
1415  std::swap(l, r);
1416 
1417  // Handle sizes, propagate error type if present.
1418  if (llvm::isa<SizeType>(l)) {
1419  if (llvm::isa<SizeType, IndexType>(r))
1420  acc = l;
1421  else
1422  return emitOptionalError(location, "requires all sizes or shapes");
1423  } else if (llvm::isa<IndexType>(l)) {
1424  if (llvm::isa<IndexType>(r))
1425  acc = r;
1426  else
1427  return emitOptionalError(location, "requires all sizes or shapes");
1428  } else if (llvm::isa<ShapeType>(l)) {
1429  // Handle shapes, propagate error type if present.
1430  if (isShapeType(r))
1431  acc = l;
1432  else
1433  return emitOptionalError(location, "requires all sizes or shapes");
1434  } else if (isExtentTensorType(l)) {
1435  auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
1436  auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
1437  if (ShapedType::isDynamic(rank1))
1438  acc = l;
1439  else if (ShapedType::isDynamic(rank2))
1440  acc = r;
1441  else if (rank1 != rank2)
1442  return emitOptionalError(location, "unequal shape cardinality");
1443  else
1444  acc = l;
1445  }
1446  }
1447  inferredReturnTypes.assign({acc});
1448  return success();
1449 }
1450 
1451 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1452  if (l.size() != 1 || r.size() != 1)
1453  return false;
1454  if (l == r)
1455  return true;
1456 
1457  Type lhs = l.front();
1458  Type rhs = r.front();
1459 
1460  if (!llvm::isa<ShapeType, SizeType>(lhs))
1461  std::swap(lhs, rhs);
1462 
1463  if (llvm::isa<SizeType>(lhs))
1464  return llvm::isa<SizeType, IndexType>(rhs);
1465  if (llvm::isa<ShapeType>(lhs))
1466  return llvm::isa<ShapeType, TensorType>(rhs);
1467 
1468  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1469  return true;
1470  return false;
1471 }
1472 
1473 //===----------------------------------------------------------------------===//
1474 // RankOp
1475 //===----------------------------------------------------------------------===//
1476 
1477 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1478  auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1479  if (!shape)
1480  return {};
1481  int64_t rank = shape.getNumElements();
1482  Builder builder(getContext());
1483  return builder.getIndexAttr(rank);
1484 }
1485 
1486 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1487 /// Constant folding fails in cases where only the rank is constant, not the
1488 /// shape itself.
1489 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1490 ///
1491 /// Example:
1492 ///
1493 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1494 /// %rank = shape.rank %shape
1495 ///
1496 /// becomes
1497 ///
1498 /// %rank = shape.const_size 3
1499 
1500 namespace {
1501 struct RankShapeOfCanonicalizationPattern
1502  : public OpRewritePattern<shape::RankOp> {
1504 
1505  LogicalResult matchAndRewrite(shape::RankOp op,
1506  PatternRewriter &rewriter) const override {
1507  auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1508  if (!shapeOfOp)
1509  return failure();
1510  auto rankedTensorType =
1511  llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1512  if (!rankedTensorType)
1513  return failure();
1514  int64_t rank = rankedTensorType.getRank();
1515  if (llvm::isa<IndexType>(op.getType())) {
1516  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1517  rank);
1518  } else if (llvm::isa<shape::SizeType>(op.getType())) {
1519  rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1520  } else {
1521  return failure();
1522  }
1523  return success();
1524  }
1525 };
1526 } // namespace
1527 
1528 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1529  MLIRContext *context) {
1530  patterns.add<RankShapeOfCanonicalizationPattern>(context);
1531 }
1532 
1533 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1534  MLIRContext *context, std::optional<Location> location,
1535  RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1536  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1537  inferredReturnTypes.assign({SizeType::get(context)});
1538  else
1539  inferredReturnTypes.assign({IndexType::get(context)});
1540  return success();
1541 }
1542 
1543 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1544  // SizeType is compatible with IndexType.
1545  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1546 }
1547 
1548 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1549 
1550 //===----------------------------------------------------------------------===//
1551 // NumElementsOp
1552 //===----------------------------------------------------------------------===//
1553 
1554 OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1555 
1556  // Fold only when argument constant.
1557  Attribute shape = adaptor.getShape();
1558  if (!shape)
1559  return {};
1560 
1561  APInt product(64, 1);
1562  for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
1563  product *= value;
1564  Builder builder(getContext());
1565  return builder.getIndexAttr(product.getLimitedValue());
1566 }
1567 
1568 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1569  MLIRContext *context, std::optional<Location> location,
1570  NumElementsOp::Adaptor adaptor,
1571  SmallVectorImpl<Type> &inferredReturnTypes) {
1572  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1573  inferredReturnTypes.assign({SizeType::get(context)});
1574  else
1575  inferredReturnTypes.assign({IndexType::get(context)});
1576  return success();
1577 }
1578 
1579 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1580  TypeRange r) {
1581  // SizeType is compatible with IndexType.
1582  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1583 }
1584 
1585 LogicalResult shape::NumElementsOp::verify() {
1586  return verifySizeOrIndexOp(*this);
1587 }
1588 
1589 //===----------------------------------------------------------------------===//
1590 // MaxOp
1591 //===----------------------------------------------------------------------===//
1592 
1593 OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
1594  // If operands are equal, just propagate one.
1595  if (getLhs() == getRhs())
1596  return getLhs();
1597  return nullptr;
1598 }
1599 
1600 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1601  MLIRContext *context, std::optional<Location> location,
1602  MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1603  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1604  inferredReturnTypes.assign({adaptor.getLhs().getType()});
1605  else
1606  inferredReturnTypes.assign({SizeType::get(context)});
1607  return success();
1608 }
1609 
1610 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1611  if (l.size() != 1 || r.size() != 1)
1612  return false;
1613  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1614  return true;
1615  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1616  return true;
1617  return false;
1618 }
1619 
1620 //===----------------------------------------------------------------------===//
1621 // MinOp
1622 //===----------------------------------------------------------------------===//
1623 
1624 OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
1625  // If operands are equal, just propagate one.
1626  if (getLhs() == getRhs())
1627  return getLhs();
1628  return nullptr;
1629 }
1630 
1631 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1632  MLIRContext *context, std::optional<Location> location,
1633  MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1634  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1635  inferredReturnTypes.assign({adaptor.getLhs().getType()});
1636  else
1637  inferredReturnTypes.assign({SizeType::get(context)});
1638  return success();
1639 }
1640 
1641 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1642  if (l.size() != 1 || r.size() != 1)
1643  return false;
1644  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1645  return true;
1646  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1647  return true;
1648  return false;
1649 }
1650 
1651 //===----------------------------------------------------------------------===//
1652 // MulOp
1653 //===----------------------------------------------------------------------===//
1654 
1655 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1656  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1657  if (!lhs)
1658  return nullptr;
1659  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1660  if (!rhs)
1661  return nullptr;
1662  APInt folded = lhs.getValue() * rhs.getValue();
1663  Type indexTy = IndexType::get(getContext());
1664  return IntegerAttr::get(indexTy, folded);
1665 }
1666 
1667 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1668  MLIRContext *context, std::optional<Location> location,
1669  MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1670  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1671  llvm::isa<SizeType>(adaptor.getRhs().getType()))
1672  inferredReturnTypes.assign({SizeType::get(context)});
1673  else
1674  inferredReturnTypes.assign({IndexType::get(context)});
1675  return success();
1676 }
1677 
1678 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1679  // SizeType is compatible with IndexType.
1680  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1681 }
1682 
1683 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1684 
1685 //===----------------------------------------------------------------------===//
1686 // ShapeOfOp
1687 //===----------------------------------------------------------------------===//
1688 
1689 namespace {
1690 /// Replace shape_of(x) where x has a constant shape with a const_shape op.
1691 struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
1693 
1694  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1695  PatternRewriter &rewriter) const override {
1696  auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1697  if (!type || !type.hasStaticShape())
1698  return failure();
1699  Location loc = op.getLoc();
1700  Value constShape =
1701  rewriter
1702  .create<ConstShapeOp>(loc,
1703  rewriter.getIndexTensorAttr(type.getShape()))
1704  .getResult();
1705  if (constShape.getType() != op.getResult().getType())
1706  constShape = rewriter.create<tensor::CastOp>(
1707  loc, op.getResult().getType(), constShape);
1708  rewriter.replaceOp(op, constShape);
1709  return success();
1710  }
1711 };
1712 
1713 // Canonicalize
1714 //
1715 // %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1716 // %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
1717 //
1718 // to
1719 //
1720 // %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1721 // %1 = %shape
1722 //
1723 struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
1725 
1726  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1727  PatternRewriter &rewriter) const override {
1728  auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1729  if (!tensorReshapeOp)
1730  return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
1731  if (!isa<TensorType>(op.getType()))
1732  return rewriter.notifyMatchFailure(op, "result is not a tensor");
1733 
1734  // Operand 'shape' of 'tensor.reshape' may now be used as the result of
1735  // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
1736  // formed IR, it may not be identical (dynamically vs statically shaped),
1737  // in which case it needs to be cast first.
1738  Value shape = tensorReshapeOp.getShape();
1739  if (op.getType() != shape.getType())
1740  shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
1741 
1742  rewriter.replaceOp(op, shape);
1743  return success();
1744  }
1745 };
1746 
1747 // Canonicalize
1748 // ```
1749 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1750 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1751 // ```
1752 // to
1753 // ```
1754 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1755 // ```
1756 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1758 
1759  LogicalResult matchAndRewrite(tensor::CastOp op,
1760  PatternRewriter &rewriter) const override {
1761  auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1762  if (!ty || ty.getRank() != 1)
1763  return failure();
1764 
1765  auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1766  if (!shapeOfOp)
1767  return failure();
1768 
1769  // Argument type must be ranked and must not conflict.
1770  auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1771  if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1772  return failure();
1773 
1774  rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1775  return success();
1776  }
1777 };
1778 } // namespace
1779 
1780 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1781  MLIRContext *context) {
1782  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1783  ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1784  context);
1785 }
1786 
1787 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1788  MLIRContext *context, std::optional<Location> location,
1789  ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1790  if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1791  inferredReturnTypes.assign({ShapeType::get(context)});
1792  else {
1793  auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1794  int64_t rank =
1795  shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1796  Type indexTy = IndexType::get(context);
1797  Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1798  inferredReturnTypes.assign({extentTensorTy});
1799  }
1800  return success();
1801 }
1802 
1803 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1804  if (l.size() != 1 || r.size() != 1)
1805  return false;
1806  if (l == r)
1807  return true;
1808 
1809  Type lhs = l.front();
1810  Type rhs = r.front();
1811 
1812  if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1813  !llvm::isa<ShapeType, ShapedType>(rhs))
1814  return false;
1815 
1816  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1817  // Shape type is compatible with all other valid return types.
1818  return true;
1819 
1820  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1821  return true;
1822  return false;
1823 }
1824 
1825 LogicalResult shape::ShapeOfOp::verify() {
1826  return verifyShapeOrExtentTensorOp(*this);
1827 }
1828 
1829 //===----------------------------------------------------------------------===//
1830 // SizeToIndexOp
1831 //===----------------------------------------------------------------------===//
1832 
1833 OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1834  // Constant values of both types, `shape.size` and `index`, are represented as
1835  // `IntegerAttr`s which makes constant folding simple.
1836  if (Attribute arg = adaptor.getArg())
1837  return arg;
1838  return OpFoldResult();
1839 }
1840 
1841 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1842  MLIRContext *context) {
1843  patterns.add<IndexToSizeToIndexCanonicalization>(context);
1844 }
1845 
1846 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1847  if (inputs.size() != 1 || outputs.size() != 1)
1848  return false;
1849  return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1850  llvm::isa<IndexType>(outputs[0]);
1851 }
1852 
1853 //===----------------------------------------------------------------------===//
1854 // YieldOp
1855 //===----------------------------------------------------------------------===//
1856 
1857 LogicalResult shape::YieldOp::verify() {
1858  auto *parentOp = (*this)->getParentOp();
1859  auto results = parentOp->getResults();
1860  auto operands = getOperands();
1861 
1862  if (parentOp->getNumResults() != getNumOperands())
1863  return emitOpError() << "number of operands does not match number of "
1864  "results of its parent";
1865  for (auto e : llvm::zip(results, operands))
1866  if (std::get<0>(e).getType() != std::get<1>(e).getType())
1867  return emitOpError() << "types mismatch between yield op and its parent";
1868 
1869  return success();
1870 }
1871 
1872 //===----------------------------------------------------------------------===//
1873 // SplitAtOp
1874 //===----------------------------------------------------------------------===//
1875 
1876 LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1877  SmallVectorImpl<OpFoldResult> &results) {
1878  if (!adaptor.getOperand() || !adaptor.getIndex())
1879  return failure();
1880  auto shapeVec = llvm::to_vector<6>(
1881  llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1882  auto shape = llvm::ArrayRef(shapeVec);
1883  auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1884  // Verify that the split point is in the correct range.
1885  // TODO: Constant fold to an "error".
1886  int64_t rank = shape.size();
1887  if (-rank > splitPoint || splitPoint > rank)
1888  return failure();
1889  if (splitPoint < 0)
1890  splitPoint += shape.size();
1891  Builder builder(adaptor.getOperand().getContext());
1892  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1893  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1894  return success();
1895 }
1896 
1897 //===----------------------------------------------------------------------===//
1898 // ToExtentTensorOp
1899 //===----------------------------------------------------------------------===//
1900 
1901 OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1902  if (!adaptor.getInput())
1903  return OpFoldResult();
1904  Builder builder(getContext());
1905  auto shape = llvm::to_vector<6>(
1906  llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1907  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1908  builder.getIndexType());
1909  return DenseIntElementsAttr::get(type, shape);
1910 }
1911 
1912 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1913  if (inputs.size() != 1 || outputs.size() != 1)
1914  return false;
1915  if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1916  if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1917  inputTensor.getRank() != 1)
1918  return false;
1919  } else if (!llvm::isa<ShapeType>(inputs[0])) {
1920  return false;
1921  }
1922 
1923  TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1924  return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1925 }
1926 
1927 //===----------------------------------------------------------------------===//
1928 // ReduceOp
1929 //===----------------------------------------------------------------------===//
1930 
1931 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1932  ValueRange initVals) {
1933  OpBuilder::InsertionGuard g(builder);
1934  result.addOperands(shape);
1935  result.addOperands(initVals);
1936 
1937  Region *bodyRegion = result.addRegion();
1938  Block *bodyBlock = builder.createBlock(
1939  bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);
1940 
1941  Type elementType;
1942  if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
1943  elementType = tensorType.getElementType();
1944  else
1945  elementType = SizeType::get(builder.getContext());
1946  bodyBlock->addArgument(elementType, shape.getLoc());
1947 
1948  for (Value initVal : initVals) {
1949  bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
1950  result.addTypes(initVal.getType());
1951  }
1952 }
1953 
1954 LogicalResult ReduceOp::verify() {
1955  // Verify block arg types.
1956  Block &block = getRegion().front();
1957 
1958  // The block takes index, extent, and aggregated values as arguments.
1959  auto blockArgsCount = getInitVals().size() + 2;
1960  if (block.getNumArguments() != blockArgsCount)
1961  return emitOpError() << "ReduceOp body is expected to have "
1962  << blockArgsCount << " arguments";
1963 
1964  // The first block argument is the index and must always be of type `index`.
1965  if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
1966  return emitOpError(
1967  "argument 0 of ReduceOp body is expected to be of IndexType");
1968 
1969  // The second block argument is the extent and must be of type `size` or
1970  // `index`, depending on whether the reduce operation is applied to a shape or
1971  // to an extent tensor.
1972  Type extentTy = block.getArgument(1).getType();
1973  if (llvm::isa<ShapeType>(getShape().getType())) {
1974  if (!llvm::isa<SizeType>(extentTy))
1975  return emitOpError("argument 1 of ReduceOp body is expected to be of "
1976  "SizeType if the ReduceOp operates on a ShapeType");
1977  } else {
1978  if (!llvm::isa<IndexType>(extentTy))
1979  return emitOpError(
1980  "argument 1 of ReduceOp body is expected to be of IndexType if the "
1981  "ReduceOp operates on an extent tensor");
1982  }
1983 
1984  for (const auto &type : llvm::enumerate(getInitVals()))
1985  if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1986  return emitOpError() << "type mismatch between argument "
1987  << type.index() + 2
1988  << " of ReduceOp body and initial value "
1989  << type.index();
1990  return success();
1991 }
1992 
1993 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1994  // Parse operands.
1996  Type shapeOrExtentTensorType;
1997  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1999  parser.parseColonType(shapeOrExtentTensorType) ||
2000  parser.parseOptionalArrowTypeList(result.types))
2001  return failure();
2002 
2003  // Resolve operands.
2004  auto initVals = llvm::ArrayRef(operands).drop_front();
2005  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
2006  result.operands) ||
2007  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2008  result.operands))
2009  return failure();
2010 
2011  // Parse the body.
2012  Region *body = result.addRegion();
2013  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
2014  return failure();
2015 
2016  // Parse attributes.
2017  if (parser.parseOptionalAttrDict(result.attributes))
2018  return failure();
2019 
2020  return success();
2021 }
2022 
2023 void ReduceOp::print(OpAsmPrinter &p) {
2024  p << '(' << getShape() << ", " << getInitVals()
2025  << ") : " << getShape().getType();
2026  p.printOptionalArrowTypeList(getResultTypes());
2027  p << ' ';
2028  p.printRegion(getRegion());
2029  p.printOptionalAttrDict((*this)->getAttrs());
2030 }
2031 
2032 #define GET_OP_CLASSES
2033 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2034 
2035 #define GET_TYPEDEF_CLASSES
2036 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
static bool isErrorPropagationPossible(TypeRange operandTypes)
Definition: Shape.cpp:67
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
Definition: Shape.cpp:970
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
Definition: Shape.cpp:84
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
Definition: Shape.cpp:97
static LogicalResult verifySizeOrIndexOp(Operation *op)
Definition: Shape.cpp:72
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
Operation & front()
Definition: Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:120
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:100
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:302
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:233
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:95
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:134
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:221
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:454
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:451
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:102
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
ValueTypeRange< ValueRange > type_range
Definition: ValueRange.h:410
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition: Traits.cpp:25
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:60
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
bool isExtentTensorType(Type)
Definition: Shape.cpp:45
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
Definition: Shape.cpp:50
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition: Shape.cpp:41
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.