1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include <utility>
17 #include "mlir/Dialect/Traits.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
27 #include "llvm/ADT/SetOperations.h"
28 #include "llvm/ADT/SmallString.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/raw_ostream.h"
32 using namespace mlir;
33 using namespace mlir::shape;
35 #include "mlir/Dialect/Shape/IR/"
37 namespace {
38 #include ""
39 } // namespace
41 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
42  return RankedTensorType::get({rank}, IndexType::get(ctx));
43 }
46  auto ranked = llvm::dyn_cast<RankedTensorType>(type);
47  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
48 }
50 LogicalResult shape::getShapeVec(Value input,
51  SmallVectorImpl<int64_t> &shapeValues) {
52  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
53  auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
54  if (!type.hasRank())
55  return failure();
56  llvm::append_range(shapeValues, type.getShape());
57  return success();
58  }
60  if (matchPattern(input, m_Constant(&attr))) {
61  llvm::append_range(shapeValues, attr.getValues<int64_t>());
62  return success();
63  }
64  return failure();
65 }
67 static bool isErrorPropagationPossible(TypeRange operandTypes) {
68  return llvm::any_of(operandTypes,
69  llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
70 }
72 static LogicalResult verifySizeOrIndexOp(Operation *op) {
73  assert(op != nullptr && op->getNumResults() == 1);
74  Type resultTy = op->getResultTypes().front();
76  if (!llvm::isa<SizeType>(resultTy))
77  return op->emitOpError()
78  << "if at least one of the operands can hold error values then "
79  "the result must be of type `size` to propagate them";
80  }
81  return success();
82 }
84 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
85  assert(op != nullptr && op->getNumResults() == 1);
86  Type resultTy = op->getResultTypes().front();
88  if (!llvm::isa<ShapeType>(resultTy))
89  return op->emitOpError()
90  << "if at least one of the operands can hold error values then "
91  "the result must be of type `shape` to propagate them";
92  }
93  return success();
94 }
96 template <typename... Ty>
97 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
98  return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
99 }
101 template <typename... Ty, typename... ranges>
102 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
103  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
104 }
106 //===----------------------------------------------------------------------===//
107 // InlinerInterface
108 //===----------------------------------------------------------------------===//
110 namespace {
111 /// This class defines the interface for inlining shape dialect ops.
112 struct ShapeInlinerInterface : public DialectInlinerInterface {
115  // Returns true if the given region 'src' can be inlined into the region
116  // 'dest' that is attached to an operation registered to the current dialect.
117  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
118  IRMapping &) const final {
119  return true;
120  }
122  // Returns true if the given operation 'op', that is registered to this
123  // dialect, can be inlined into the region 'dest' that is attached to an
124  // operation registered to the current dialect.
125  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
126  IRMapping &) const final {
127  return true;
128  }
129 };
130 } // namespace
132 void ShapeDialect::initialize() {
133  addOperations<
134 #define GET_OP_LIST
135 #include "mlir/Dialect/Shape/IR/"
136  >();
137  addTypes<
138 #define GET_TYPEDEF_LIST
139 #include "mlir/Dialect/Shape/IR/"
140  >();
141  addInterfaces<ShapeInlinerInterface>();
142  // Allow unknown operations during prototyping and testing. As the dialect is
143  // still evolving it makes it simple to start with an unregistered ops and
144  // try different variants before actually defining the op.
145  allowUnknownOperations();
146  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
147  AssumingYieldOp>();
148 }
151  Attribute value, Type type,
152  Location loc) {
153  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
154  return builder.create<ub::PoisonOp>(loc, type, poison);
156  if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
157  return builder.create<ConstShapeOp>(
158  loc, type, llvm::cast<DenseIntElementsAttr>(value));
159  if (llvm::isa<SizeType>(type))
160  return builder.create<ConstSizeOp>(loc, type,
161  llvm::cast<IntegerAttr>(value));
162  if (llvm::isa<WitnessType>(type))
163  return builder.create<ConstWitnessOp>(loc, type,
164  llvm::cast<BoolAttr>(value));
166  return arith::ConstantOp::materialize(builder, value, type, loc);
167 }
169 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
170  NamedAttribute attribute) {
171  // Verify shape.lib attribute.
172  if (attribute.getName() == "shape.lib") {
173  if (!op->hasTrait<OpTrait::SymbolTable>())
174  return op->emitError(
175  "shape.lib attribute may only be on op implementing SymbolTable");
177  if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
178  auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
179  if (!symbol)
180  return op->emitError("shape function library ")
181  << symbolRef << " not found";
182  return isa<shape::FunctionLibraryOp>(symbol)
183  ? success()
184  : op->emitError()
185  << symbolRef << " required to be shape function library";
186  }
188  if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
189  // Verify all entries are function libraries and mappings in libraries
190  // refer to unique ops.
192  for (auto it : arr) {
193  if (!llvm::isa<SymbolRefAttr>(it))
194  return op->emitError(
195  "only SymbolRefAttr allowed in shape.lib attribute array");
197  auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
198  SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
199  if (!shapeFnLib)
200  return op->emitError()
201  << it << " does not refer to FunctionLibraryOp";
202  for (auto mapping : shapeFnLib.getMapping()) {
203  if (!key.insert(mapping.getName()).second) {
204  return op->emitError("only one op to shape mapping allowed, found "
205  "multiple for `")
206  << mapping.getName() << "`";
207  }
208  }
209  }
210  return success();
211  }
213  return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
214  "allowed as shape.lib attribute");
215  }
216  return success();
217 }
219 //===----------------------------------------------------------------------===//
220 // AnyOp
221 //===----------------------------------------------------------------------===//
223 // TODO: Canonicalization should be implemented for shapes that can be
224 // determined through mixtures of the known dimensions of the inputs.
225 OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
226  // Only the last operand is checked because AnyOp is commutative.
227  if (adaptor.getInputs().back())
228  return adaptor.getInputs().back();
230  return nullptr;
231 }
233 //===----------------------------------------------------------------------===//
234 // AssumingOp
235 //===----------------------------------------------------------------------===//
237 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
238  result.regions.reserve(1);
239  Region *doRegion = result.addRegion();
241  auto &builder = parser.getBuilder();
243  if (parser.parseOperand(cond) ||
244  parser.resolveOperand(cond, builder.getType<WitnessType>(),
245  result.operands))
246  return failure();
248  // Parse optional results type list.
249  if (parser.parseOptionalArrowTypeList(result.types))
250  return failure();
252  // Parse the region and add a terminator if elided.
253  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
254  return failure();
255  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
257  // Parse the optional attribute list.
258  if (parser.parseOptionalAttrDict(result.attributes))
259  return failure();
260  return success();
261 }
264  bool yieldsResults = !getResults().empty();
266  p << " " << getWitness();
267  if (yieldsResults)
268  p << " -> (" << getResultTypes() << ")";
269  p << ' ';
270  p.printRegion(getDoRegion(),
271  /*printEntryBlockArgs=*/false,
272  /*printBlockTerminators=*/yieldsResults);
273  p.printOptionalAttrDict((*this)->getAttrs());
274 }
276 namespace {
277 // Removes AssumingOp with a passing witness and inlines the region.
278 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
281  LogicalResult matchAndRewrite(AssumingOp op,
282  PatternRewriter &rewriter) const override {
283  auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
284  if (!witness || !witness.getPassingAttr())
285  return failure();
287  AssumingOp::inlineRegionIntoParent(op, rewriter);
288  return success();
289  }
290 };
292 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
295  LogicalResult matchAndRewrite(AssumingOp op,
296  PatternRewriter &rewriter) const override {
297  Block *body = op.getBody();
298  auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
300  // Find used values.
301  SmallVector<Value, 4> newYieldOperands;
302  for (auto [opResult, yieldOperand] :
303  llvm::zip(op.getResults(), yieldOp.getOperands())) {
304  if (!opResult.getUses().empty()) {
305  newYieldOperands.push_back(yieldOperand);
306  }
307  }
309  // Rewrite only if redundant results exist.
310  if (newYieldOperands.size() == yieldOp->getNumOperands())
311  return failure();
313  // Replace yield op in the old assuming op's body and move the entire region
314  // to the new assuming op.
315  rewriter.setInsertionPointToEnd(body);
316  auto newYieldOp =
317  rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
318  rewriter.setInsertionPoint(op);
319  auto newOp = rewriter.create<AssumingOp>(
320  op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
321  newOp.getDoRegion().takeBody(op.getDoRegion());
323  // Use the new results to replace the previously used ones.
324  SmallVector<Value, 4> replacementValues;
325  auto src = newOp.getResults().begin();
326  for (auto it : op.getResults()) {
327  if (it.getUses().empty())
328  replacementValues.push_back(nullptr);
329  else
330  replacementValues.push_back(*src++);
331  }
332  rewriter.replaceOp(op, replacementValues);
333  return success();
334  }
335 };
336 } // namespace
338 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
339  MLIRContext *context) {
340  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
341 }
343 // See RegionBranchOpInterface in Interfaces/
344 void AssumingOp::getSuccessorRegions(
346  // AssumingOp has unconditional control flow into the region and back to the
347  // parent, so return the correct RegionSuccessor purely based on the index
348  // being None or 0.
349  if (!point.isParent()) {
350  regions.push_back(RegionSuccessor(getResults()));
351  return;
352  }
354  regions.push_back(RegionSuccessor(&getDoRegion()));
355 }
357 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
358  PatternRewriter &rewriter) {
359  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
360  auto *assumingBlock = op.getBody();
361  auto initPosition = rewriter.getInsertionPoint();
362  auto *blockAfterAssuming =
363  rewriter.splitBlock(blockBeforeAssuming, initPosition);
365  // Remove the AssumingOp and AssumingYieldOp.
366  auto &yieldOp = assumingBlock->back();
367  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
368  rewriter.replaceOp(op, yieldOp.getOperands());
369  rewriter.eraseOp(&yieldOp);
371  // Merge blocks together as there was no branching behavior from the
372  // AssumingOp.
373  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
374  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
375 }
377 void AssumingOp::build(
378  OpBuilder &builder, OperationState &result, Value witness,
380  OpBuilder::InsertionGuard g(builder);
382  result.addOperands(witness);
383  Region *bodyRegion = result.addRegion();
384  builder.createBlock(bodyRegion);
386  // Build body.
387  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
388  builder.create<AssumingYieldOp>(result.location, yieldValues);
390  SmallVector<Type, 2> assumingTypes;
391  for (Value v : yieldValues)
392  assumingTypes.push_back(v.getType());
393  result.addTypes(assumingTypes);
394 }
396 //===----------------------------------------------------------------------===//
397 // AddOp
398 //===----------------------------------------------------------------------===//
/// Infers the result type of `shape.add`: `size` if either operand is of type
/// `size`, `index` otherwise. Always succeeds.
LogicalResult mlir::shape::AddOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  Type lhsTy = adaptor.getLhs().getType();
  Type rhsTy = adaptor.getRhs().getType();
  const bool anySizeOperand =
      llvm::isa<SizeType>(lhsTy) || llvm::isa<SizeType>(rhsTy);

  Type resultTy = anySizeOperand ? Type(SizeType::get(context))
                                 : Type(IndexType::get(context));
  inferredReturnTypes.assign({resultTy});
  return success();
}
411 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
412  // SizeType is compatible with IndexType.
413  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
414 }
/// Folds `shape.add`: `add(x, 0)` becomes `x`; otherwise the operands are
/// constant-folded as integer attributes.
OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
  // add(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  auto addInts = [](APInt lhs, const APInt &rhs) {
    lhs += rhs;
    return lhs;
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), addInts);
}
426 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
428 //===----------------------------------------------------------------------===//
429 // AssumingAllOp
430 //===----------------------------------------------------------------------===//
432 namespace {
434 // Merge multiple `shape.assuming_all` operations together.
435 //
436 // %0 = shape.assuming_all %w0, %w1
437 // %1 = shape.assuming_all %w2, %0
438 //
439 // to:
440 //
441 // %0 = shape.assuming_all %w0, %w2, %w2
442 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
445  LogicalResult matchAndRewrite(AssumingAllOp op,
446  PatternRewriter &rewriter) const override {
447  SmallVector<Value> operands;
449  for (Value operand : op.getInputs()) {
450  if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
451  operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
452  else
453  operands.push_back(operand);
454  }
456  // We didn't find any other `assuming_all` ops to merge with.
457  if (operands.size() == op.getNumOperands())
458  return failure();
460  // Replace with a new `assuming_all` operation with merged constraints.
461  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
462  return success();
463  }
464 };
466 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
467 // are subsumed by others.
468 //
469 // %0 = shape.cstr_broadcastable %shape0, %shape1
470 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
471 //
472 // %2 = shape.cstr_broadcastable %shape3, %shape4
473 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
474 //
475 // %4 = shape.assuming_all %0, %1, %2, %3
476 //
477 // to:
478 //
479 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
480 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
481 // %2 = shape.assuming_all %0, %1
482 //
483 // In this example if shapes [0, 1, 2] are broadcastable, then it means that
484 // shapes [0, 1] are broadcastable too, and can be removed from the list of
485 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
486 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
487 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
490  LogicalResult matchAndRewrite(AssumingAllOp op,
491  PatternRewriter &rewriter) const override {
492  // Collect all `CstrBroadcastableOp` operands first.
494  for (Value operand : op.getInputs()) {
495  // TODO: Apply this optimization if some of the witnesses are not
496  // produced by the `cstr_broadcastable`.
497  auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
498  if (!broadcastable)
499  return failure();
501  operands.insert(broadcastable);
502  }
504  // Skip trivial `assuming_all` operations.
505  if (operands.size() <= 1)
506  return failure();
508  // Collect shapes checked by `cstr_broadcastable` operands.
510  for (auto cstr : operands) {
511  DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
512  shapes.emplace_back(cstr, std::move(shapesSet));
513  }
515  // Sort by the number of shape operands (larger to smaller).
516  llvm::sort(shapes, [](auto a, auto b) {
517  return a.first.getNumOperands() > b.first.getNumOperands();
518  });
520  // We start from the `cst_broadcastable` operations with largest number of
521  // shape operands, and remove redundant `cst_broadcastable` operations. We
522  // do this until we find a set of `cst_broadcastable` operations with
523  // non-overlapping constraints.
524  SmallVector<CstrBroadcastableOp> markedForErase;
526  for (unsigned i = 0; i < shapes.size(); ++i) {
527  auto isSubset = [&](auto pair) {
528  return llvm::set_is_subset(pair.second, shapes[i].second);
529  };
531  // Keep redundant `cstr_broadcastable` operations to be erased.
532  auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
533  for (auto *it0 = it; it0 < shapes.end(); ++it0)
534  markedForErase.push_back(it0->first);
535  shapes.erase(it, shapes.end());
536  }
538  // We didn't find any operands that could be removed.
539  if (markedForErase.empty())
540  return failure();
542  // Collect non-overlapping `cst_broadcastable` constraints.
543  SmallVector<Value> uniqueConstraints;
544  for (auto &shape : shapes)
545  uniqueConstraints.push_back(shape.first.getResult());
547  // Replace with a new `assuming_all` operation ...
548  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
550  // ... and maybe erase `cstr_broadcastable` ops without uses.
551  for (auto &op : markedForErase)
552  if (op->use_empty())
553  rewriter.eraseOp(op);
555  return success();
556  }
557 };
559 struct AssumingAllToCstrEqCanonicalization
560  : public OpRewritePattern<AssumingAllOp> {
563  LogicalResult matchAndRewrite(AssumingAllOp op,
564  PatternRewriter &rewriter) const override {
565  SmallVector<Value, 8> shapes;
566  for (Value w : op.getInputs()) {
567  auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
568  if (!cstrEqOp)
569  return failure();
570  bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
571  return llvm::is_contained(shapes, s);
572  });
573  if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
574  return failure();
575  shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
576  }
577  rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
578  return success();
579  }
580 };
582 template <typename OpTy>
583 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
586  LogicalResult matchAndRewrite(OpTy op,
587  PatternRewriter &rewriter) const override {
588  // Find unique operands.
589  SetVector<Value> unique(op.operand_begin(), op.operand_end());
591  // Reduce op to equivalent with unique operands.
592  if (unique.size() < op.getNumOperands()) {
593  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
594  unique.takeVector(), op->getAttrs());
595  return success();
596  }
598  return failure();
599  }
600 };
601 } // namespace
603 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
604  MLIRContext *context) {
605  patterns
606  .add<MergeAssumingAllOps, AssumingAllOneOp,
607  AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
608  RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
609 }
611 OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
612  // Iterate in reverse to first handle all constant operands. They are
613  // guaranteed to be the tail of the inputs because this is commutative.
614  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
615  Attribute a = adaptor.getInputs()[idx];
616  // Cannot fold if any inputs are not constant;
617  if (!a)
618  return nullptr;
620  // We do not need to keep statically known values after handling them in
621  // this method.
622  getOperation()->eraseOperand(idx);
624  // Always false if any input is statically known false
625  if (!llvm::cast<BoolAttr>(a).getValue())
626  return a;
627  }
628  // If this is reached, all inputs were statically known passing.
629  return BoolAttr::get(getContext(), true);
630 }
632 LogicalResult AssumingAllOp::verify() {
633  // Ensure that AssumingAllOp contains at least one operand
634  if (getNumOperands() == 0)
635  return emitOpError("no operands specified");
637  return success();
638 }
640 //===----------------------------------------------------------------------===//
641 // BroadcastOp
642 //===----------------------------------------------------------------------===//
644 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
645  if (getShapes().size() == 1) {
646  // Otherwise, we need a cast which would be a canonicalization, not folding.
647  if (getShapes().front().getType() != getType())
648  return nullptr;
649  return getShapes().front();
650  }
652  // TODO: Support folding with more than 2 input shapes
653  if (getShapes().size() > 2)
654  return nullptr;
656  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
657  return nullptr;
658  auto lhsShape = llvm::to_vector<6>(
659  llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
660  .getValues<int64_t>());
661  auto rhsShape = llvm::to_vector<6>(
662  llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
663  .getValues<int64_t>());
664  SmallVector<int64_t, 6> resultShape;
666  // If the shapes are not compatible, we can't fold it.
667  // TODO: Fold to an "error".
668  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
669  return nullptr;
671  Builder builder(getContext());
672  return builder.getIndexTensorAttr(resultShape);
673 }
675 LogicalResult BroadcastOp::verify() {
676  return verifyShapeOrExtentTensorOp(*this);
677 }
679 namespace {
680 template <typename OpTy>
681 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
684  LogicalResult matchAndRewrite(OpTy op,
685  PatternRewriter &rewriter) const override {
686  auto isPotentiallyNonEmptyShape = [](Value shape) {
687  if (auto extentTensorTy =
688  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
689  if (extentTensorTy.getDimSize(0) == 0)
690  return false;
691  }
692  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
693  if (constShape.getShape().empty())
694  return false;
695  }
696  return true;
697  };
698  auto newOperands = llvm::to_vector<8>(
699  llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
701  // Reduce op to equivalent without empty shape operands.
702  if (newOperands.size() < op.getNumOperands()) {
703  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
704  op->getAttrs());
705  return success();
706  }
708  return failure();
709  }
710 };
712 struct BroadcastForwardSingleOperandPattern
713  : public OpRewritePattern<BroadcastOp> {
716  LogicalResult matchAndRewrite(BroadcastOp op,
717  PatternRewriter &rewriter) const override {
718  if (op.getNumOperands() != 1)
719  return failure();
720  Value replacement = op.getShapes().front();
722  // Insert cast if needed.
723  if (replacement.getType() != op.getType()) {
724  auto loc = op.getLoc();
725  if (llvm::isa<ShapeType>(op.getType())) {
726  replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
727  } else {
728  assert(!llvm::isa<ShapeType>(op.getType()) &&
729  !llvm::isa<ShapeType>(replacement.getType()) &&
730  "expect extent tensor cast");
731  replacement =
732  rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
733  }
734  }
736  rewriter.replaceOp(op, replacement);
737  return success();
738  }
739 };
741 struct BroadcastFoldConstantOperandsPattern
742  : public OpRewritePattern<BroadcastOp> {
745  LogicalResult matchAndRewrite(BroadcastOp op,
746  PatternRewriter &rewriter) const override {
747  SmallVector<int64_t, 8> foldedConstantShape;
748  SmallVector<Value, 8> newShapeOperands;
749  for (Value shape : op.getShapes()) {
750  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
751  SmallVector<int64_t, 8> newFoldedConstantShape;
753  foldedConstantShape,
754  llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
755  newFoldedConstantShape)) {
756  foldedConstantShape = newFoldedConstantShape;
757  continue;
758  }
759  }
760  newShapeOperands.push_back(shape);
761  }
763  // Need at least two constant operands to fold anything.
764  if (op.getNumOperands() - newShapeOperands.size() < 2)
765  return failure();
767  auto foldedConstantOperandsTy = RankedTensorType::get(
768  {static_cast<int64_t>(foldedConstantShape.size())},
769  rewriter.getIndexType());
770  newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
771  op.getLoc(), foldedConstantOperandsTy,
772  rewriter.getIndexTensorAttr(foldedConstantShape)));
773  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
774  newShapeOperands);
775  return success();
776  }
777 };
779 template <typename OpTy>
780 struct CanonicalizeCastExtentTensorOperandsPattern
781  : public OpRewritePattern<OpTy> {
784  LogicalResult matchAndRewrite(OpTy op,
785  PatternRewriter &rewriter) const override {
786  // Canonicalize operands.
787  bool anyChange = false;
788  auto canonicalizeOperand = [&](Value operand) -> Value {
789  if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
790  // Only eliminate the cast if it holds no shape information.
791  bool isInformationLoosingCast =
792  llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
793  if (isInformationLoosingCast) {
794  anyChange = true;
795  return castOp.getSource();
796  }
797  }
798  return operand;
799  };
800  auto newOperands = llvm::to_vector<8>(
801  llvm::map_range(op.getOperands(), canonicalizeOperand));
803  // Rewrite op if any change required.
804  if (!anyChange)
805  return failure();
806  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
807  return success();
808  }
809 };
// Gives a `shape.broadcast` with a dynamically sized extent tensor result a
// static result type when all extent tensor operands have a static size. The
// original result type is recovered through a `tensor.cast`.
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy =
              llvm::dyn_cast<RankedTensorType>(shape.getType())) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
842 } // namespace
844 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
845  MLIRContext *context) {
846  patterns.add<BroadcastConcretizeResultTypePattern,
847  BroadcastFoldConstantOperandsPattern,
848  BroadcastForwardSingleOperandPattern,
849  CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
850  RemoveDuplicateOperandsPattern<BroadcastOp>,
851  RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
852 }
854 //===----------------------------------------------------------------------===//
855 // ConcatOp
856 //===----------------------------------------------------------------------===//
858 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
859  if (!adaptor.getLhs() || !adaptor.getRhs())
860  return nullptr;
861  auto lhsShape = llvm::to_vector<6>(
862  llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
863  auto rhsShape = llvm::to_vector<6>(
864  llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
865  SmallVector<int64_t, 6> resultShape;
866  resultShape.append(lhsShape.begin(), lhsShape.end());
867  resultShape.append(rhsShape.begin(), rhsShape.end());
868  Builder builder(getContext());
869  return builder.getIndexTensorAttr(resultShape);
870 }
872 //===----------------------------------------------------------------------===//
873 // ConstShapeOp
874 //===----------------------------------------------------------------------===//
877  p << " ";
878  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
879  p << "[";
880  interleaveComma(getShape().getValues<int64_t>(), p);
881  p << "] : ";
882  p.printType(getType());
883 }
885 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
886  if (parser.parseOptionalAttrDict(result.attributes))
887  return failure();
888  // We piggy-back on ArrayAttr parsing, though we don't internally store the
889  // shape as an ArrayAttr.
890  // TODO: Implement custom parser and maybe make syntax a bit more concise.
891  Attribute extentsRaw;
892  NamedAttrList dummy;
893  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
894  return failure();
895  auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
896  if (!extentsArray)
897  return failure();
899  for (Attribute extent : extentsArray) {
900  IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
901  if (!attr)
902  return failure();
903  ints.push_back(attr.getInt());
904  }
905  Builder &builder = parser.getBuilder();
906  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
907  Type resultTy;
908  if (parser.parseColonType(resultTy))
909  return failure();
910  result.types.push_back(resultTy);
911  return success();
912 }
914 OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
916 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
917  MLIRContext *context) {
918  patterns.add<TensorCastConstShape>(context);
919 }
/// Infers the result type as a 1-D tensor of `index` whose size is the number
/// of elements of the `shape` property. Always succeeds.
LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  const Properties &prop = adaptor.getProperties();
  const auto numExtents = static_cast<int64_t>(prop.shape.size());
  Type extentTensorTy = RankedTensorType::get({numExtents}, b.getIndexType());
  inferredReturnTypes.assign({extentTensorTy});
  return success();
}
931 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
932  TypeRange r) {
933  if (l.size() != 1 || r.size() != 1)
934  return false;
936  Type lhs = l.front();
937  Type rhs = r.front();
939  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
940  // Shape type is compatible with all other valid return types.
941  return true;
942  return lhs == rhs;
943 }
945 //===----------------------------------------------------------------------===//
946 // CstrBroadcastableOp
947 //===----------------------------------------------------------------------===//
949 void CstrBroadcastableOp::getCanonicalizationPatterns(
950  RewritePatternSet &patterns, MLIRContext *context) {
951  // Canonicalization patterns have overlap with the considerations during
952  // folding in case additional shape information is inferred at some point that
953  // does not result in folding.
954  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
955  CstrBroadcastableEqOps,
956  RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
957  RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
958 }
960 // Return true if there is exactly one attribute not representing a scalar
961 // broadcast.
963  bool nonScalarSeen = false;
964  for (Attribute a : attributes) {
965  if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
966  if (nonScalarSeen)
967  return false;
968  nonScalarSeen = true;
969  }
970  }
971  return true;
972 }
974 OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
975  // No broadcasting is needed if all operands but one are scalar.
976  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
977  return BoolAttr::get(getContext(), true);
979  if ([&] {
981  for (const auto &operand : adaptor.getShapes()) {
982  if (!operand)
983  return false;
984  extents.push_back(llvm::to_vector<6>(
985  llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
986  }
988  }())
989  return BoolAttr::get(getContext(), true);
991  // Lastly, see if folding can be completed based on what constraints are known
992  // on the input shapes.
993  if ([&] {
995  for (auto shapeValue : getShapes()) {
996  extents.emplace_back();
997  if (failed(getShapeVec(shapeValue, extents.back())))
998  return false;
999  }
1001  }())
1002  return BoolAttr::get(getContext(), true);
1004  // Because a failing witness result here represents an eventual assertion
1005  // failure, we do not replace it with a constant witness.
1006  return nullptr;
1007 }
1009 LogicalResult CstrBroadcastableOp::verify() {
1010  // Ensure that CstrBroadcastableOp contains at least two operands
1011  if (getNumOperands() < 2)
1012  return emitOpError("required at least 2 input shapes");
1013  return success();
1014 }
1016 //===----------------------------------------------------------------------===//
1017 // CstrEqOp
1018 //===----------------------------------------------------------------------===//
1020 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1021  MLIRContext *context) {
1022  // If inputs are equal, return passing witness
1023  patterns.add<CstrEqEqOps>(context);
1024 }
1026 OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
1027  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
1028  return a && a == adaptor.getShapes().front();
1029  }))
1030  return BoolAttr::get(getContext(), true);
1032  // Because a failing witness result here represents an eventual assertion
1033  // failure, we do not try to replace it with a constant witness. Similarly, we
1034  // cannot if there are any non-const inputs.
1035  return nullptr;
1036 }
1038 //===----------------------------------------------------------------------===//
1039 // ConstSizeOp
1040 //===----------------------------------------------------------------------===//
1042 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
1043  int64_t value) {
1044  build(builder, result, builder.getIndexAttr(value));
1045 }
1047 OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
1049 void ConstSizeOp::getAsmResultNames(
1050  llvm::function_ref<void(Value, StringRef)> setNameFn) {
1051  SmallString<4> buffer;
1052  llvm::raw_svector_ostream os(buffer);
1053  os << "c" << getValue();
1054  setNameFn(getResult(), os.str());
1055 }
1057 //===----------------------------------------------------------------------===//
1058 // ConstWitnessOp
1059 //===----------------------------------------------------------------------===//
1061 OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
1063 //===----------------------------------------------------------------------===//
1064 // CstrRequireOp
1065 //===----------------------------------------------------------------------===//
1067 OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
1068  return adaptor.getPred();
1069 }
1071 //===----------------------------------------------------------------------===//
1072 // DimOp
1073 //===----------------------------------------------------------------------===//
1075 std::optional<int64_t> DimOp::getConstantIndex() {
1076  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1077  return constSizeOp.getValue().getLimitedValue();
1078  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1079  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1080  return std::nullopt;
1081 }
1083 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1084  Type valType = getValue().getType();
1085  auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1086  if (!valShapedType || !valShapedType.hasRank())
1087  return nullptr;
1088  std::optional<int64_t> index = getConstantIndex();
1089  if (!index.has_value())
1090  return nullptr;
1091  if (index.value() < 0 || index.value() >= valShapedType.getRank())
1092  return nullptr;
1093  auto extent = valShapedType.getDimSize(*index);
1094  if (ShapedType::isDynamic(extent))
1095  return nullptr;
1096  return IntegerAttr::get(IndexType::get(getContext()), extent);
1097 }
1099 LogicalResult mlir::shape::DimOp::inferReturnTypes(
1100  MLIRContext *context, std::optional<Location> location,
1101  DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1102  inferredReturnTypes.assign({adaptor.getIndex().getType()});
1103  return success();
1104 }
1106 bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1107  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1108 }
1110 //===----------------------------------------------------------------------===//
1111 // DivOp
1112 //===----------------------------------------------------------------------===//
1114 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
1115  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1116  if (!lhs)
1117  return nullptr;
1118  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1119  if (!rhs)
1120  return nullptr;
1122  // Division in APInt does not follow floor(lhs, rhs) when the result is
1123  // negative. Rather, APInt rounds toward zero.
1124  APInt quotient, remainder;
1125  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1126  if (quotient.isNegative() && !remainder.isZero()) {
1127  quotient -= 1;
1128  }
1130  Type indexTy = IndexType::get(getContext());
1131  return IntegerAttr::get(indexTy, quotient);
1132 }
1134 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1135  MLIRContext *context, std::optional<Location> location,
1136  DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1137  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1138  llvm::isa<SizeType>(adaptor.getRhs().getType()))
1139  inferredReturnTypes.assign({SizeType::get(context)});
1140  else
1141  inferredReturnTypes.assign({IndexType::get(context)});
1142  return success();
1143 }
1145 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1146  // SizeType is compatible with IndexType.
1147  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1148 }
1150 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1152 //===----------------------------------------------------------------------===//
1153 // ShapeEqOp
1154 //===----------------------------------------------------------------------===//
1156 OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
1157  bool allSame = true;
1158  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1159  return {};
1160  for (Attribute operand : adaptor.getShapes().drop_front()) {
1161  if (!operand)
1162  return {};
1163  allSame = allSame && operand == adaptor.getShapes().front();
1164  }
1165  return BoolAttr::get(getContext(), allSame);
1166 }
1168 //===----------------------------------------------------------------------===//
1169 // IndexToSizeOp
1170 //===----------------------------------------------------------------------===//
1172 OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1173  // Constant values of both types, `shape.size` and `index`, are represented as
1174  // `IntegerAttr`s which makes constant folding simple.
1175  if (Attribute arg = adaptor.getArg())
1176  return arg;
1177  return {};
1178 }
1180 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1181  MLIRContext *context) {
1182  patterns.add<SizeToIndexToSizeCanonicalization>(context);
1183 }
1185 //===----------------------------------------------------------------------===//
1186 // FromExtentsOp
1187 //===----------------------------------------------------------------------===//
1189 OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1190  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
1191  return nullptr;
1192  SmallVector<int64_t, 6> extents;
1193  for (auto attr : adaptor.getExtents())
1194  extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1195  Builder builder(getContext());
1196  return builder.getIndexTensorAttr(extents);
1197 }
1199 //===----------------------------------------------------------------------===//
1200 // FunctionLibraryOp
1201 //===----------------------------------------------------------------------===//
1203 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1204  StringRef name) {
1205  result.attributes.push_back(builder.getNamedAttr(
1207 }
1209 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1210  auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1211  getMapping().get(op->getName().getIdentifier()));
1212  if (!attr)
1213  return nullptr;
1214  return lookupSymbol<FuncOp>(attr);
1215 }
1217 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1218  OperationState &result) {
1219  // Parse the op name.
1220  StringAttr nameAttr;
1221  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1222  result.attributes))
1223  return failure();
1225  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1226  return failure();
1228  auto *bodyRegion = result.addRegion();
1229  if (parser.parseRegion(*bodyRegion))
1230  return failure();
1232  if (parser.parseKeyword("mapping"))
1233  return failure();
1235  DictionaryAttr mappingAttr;
1236  if (parser.parseAttribute(mappingAttr,
1237  parser.getBuilder().getType<NoneType>(), "mapping",
1238  result.attributes))
1239  return failure();
1240  return success();
1241 }
1244  p << ' ';
1245  p.printSymbolName(getName());
1247  (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1248  p << ' ';
1249  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1250  /*printBlockTerminators=*/false);
1251  p << " mapping ";
1252  p.printAttributeWithoutType(getMappingAttr());
1253 }
1255 //===----------------------------------------------------------------------===//
1256 // FuncOp
1257 //===----------------------------------------------------------------------===//
1259 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1260  ArrayRef<NamedAttribute> attrs) {
1261  OpBuilder builder(location->getContext());
1262  OperationState state(location, getOperationName());
1263  FuncOp::build(builder, state, name, type, attrs);
1264  return cast<FuncOp>(Operation::create(state));
1265 }
1266 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1268  SmallVector<NamedAttribute, 8> attrRef(attrs);
1269  return create(location, name, type, llvm::ArrayRef(attrRef));
1270 }
1271 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1273  ArrayRef<DictionaryAttr> argAttrs) {
1274  FuncOp func = create(location, name, type, attrs);
1275  func.setAllArgAttrs(argAttrs);
1276  return func;
1277 }
1279 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
1280  FunctionType type, ArrayRef<NamedAttribute> attrs,
1281  ArrayRef<DictionaryAttr> argAttrs) {
1282  state.addAttribute(FuncOp::getSymNameAttrName(,
1283  builder.getStringAttr(name));
1284  state.addAttribute(FuncOp::getFunctionTypeAttrName(,
1285  TypeAttr::get(type));
1286  state.attributes.append(attrs.begin(), attrs.end());
1287  state.addRegion();
1289  if (argAttrs.empty())
1290  return;
1291  assert(type.getNumInputs() == argAttrs.size());
1293  builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
1294  getArgAttrsAttrName(, getResAttrsAttrName(;
1295 }
1297 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1298  auto buildFuncType =
1299  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
1301  std::string &) { return builder.getFunctionType(argTypes, results); };
1304  parser, result, /*allowVariadic=*/false,
1305  getFunctionTypeAttrName(, buildFuncType,
1306  getArgAttrsAttrName(, getResAttrsAttrName(;
1307 }
1309 void FuncOp::print(OpAsmPrinter &p) {
1311  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
1312  getArgAttrsAttrName(), getResAttrsAttrName());
1313 }
1315 //===----------------------------------------------------------------------===//
1316 // GetExtentOp
1317 //===----------------------------------------------------------------------===//
1319 std::optional<int64_t> GetExtentOp::getConstantDim() {
1320  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1321  return constSizeOp.getValue().getLimitedValue();
1322  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1323  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1324  return std::nullopt;
1325 }
1327 OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
1328  auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1329  if (!elements)
1330  return nullptr;
1331  std::optional<int64_t> dim = getConstantDim();
1332  if (!dim.has_value())
1333  return nullptr;
1334  if (dim.value() >= elements.getNumElements())
1335  return nullptr;
1336  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1337 }
1339 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1340  int64_t dim) {
1341  auto loc = result.location;
1342  auto dimAttr = builder.getIndexAttr(dim);
1343  if (llvm::isa<ShapeType>(shape.getType())) {
1344  Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1345  build(builder, result, builder.getType<SizeType>(), shape, dim);
1346  } else {
1347  Value dim =
1348  builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1349  build(builder, result, builder.getIndexType(), shape, dim);
1350  }
1351 }
1353 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1354  MLIRContext *context, std::optional<Location> location,
1355  GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1356  inferredReturnTypes.assign({IndexType::get(context)});
1357  return success();
1358 }
1360 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1361  TypeRange r) {
1362  // SizeType is compatible with IndexType.
1363  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1364 }
1366 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1368 //===----------------------------------------------------------------------===//
1369 // IsBroadcastableOp
1370 //===----------------------------------------------------------------------===//
1372 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1373  MLIRContext *context) {
1374  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1375 }
1377 OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1378  // Can always broadcast fewer than two shapes.
1379  if (adaptor.getShapes().size() < 2) {
1380  return BoolAttr::get(getContext(), true);
1381  }
1383  return nullptr;
1384 }
1386 //===----------------------------------------------------------------------===//
1387 // MeetOp
1388 //===----------------------------------------------------------------------===//
1390 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1391  MLIRContext *context, std::optional<Location> location,
1392  MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1393  if (adaptor.getOperands().empty())
1394  return failure();
1396  auto isShapeType = [](Type arg) {
1397  if (llvm::isa<ShapeType>(arg))
1398  return true;
1399  return isExtentTensorType(arg);
1400  };
1402  ValueRange::type_range types = adaptor.getOperands().getTypes();
1403  Type acc = types.front();
1404  for (auto t : drop_begin(types)) {
1405  Type l = acc, r = t;
1406  if (!llvm::isa<ShapeType, SizeType>(l))
1407  std::swap(l, r);
1409  // Handle sizes, propagate error type if present.
1410  if (llvm::isa<SizeType>(l)) {
1411  if (llvm::isa<SizeType, IndexType>(r))
1412  acc = l;
1413  else
1414  return emitOptionalError(location, "requires all sizes or shapes");
1415  } else if (llvm::isa<IndexType>(l)) {
1416  if (llvm::isa<IndexType>(r))
1417  acc = r;
1418  else
1419  return emitOptionalError(location, "requires all sizes or shapes");
1420  } else if (llvm::isa<ShapeType>(l)) {
1421  // Handle shapes, propagate error type if present.
1422  if (isShapeType(r))
1423  acc = l;
1424  else
1425  return emitOptionalError(location, "requires all sizes or shapes");
1426  } else if (isExtentTensorType(l)) {
1427  auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
1428  auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
1429  if (ShapedType::isDynamic(rank1))
1430  acc = l;
1431  else if (ShapedType::isDynamic(rank2))
1432  acc = r;
1433  else if (rank1 != rank2)
1434  return emitOptionalError(location, "unequal shape cardinality");
1435  else
1436  acc = l;
1437  }
1438  }
1439  inferredReturnTypes.assign({acc});
1440  return success();
1441 }
1443 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1444  if (l.size() != 1 || r.size() != 1)
1445  return false;
1446  if (l == r)
1447  return true;
1449  Type lhs = l.front();
1450  Type rhs = r.front();
1452  if (!llvm::isa<ShapeType, SizeType>(lhs))
1453  std::swap(lhs, rhs);
1455  if (llvm::isa<SizeType>(lhs))
1456  return llvm::isa<SizeType, IndexType>(rhs);
1457  if (llvm::isa<ShapeType>(lhs))
1458  return llvm::isa<ShapeType, TensorType>(rhs);
1460  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1461  return true;
1462  return false;
1463 }
1465 //===----------------------------------------------------------------------===//
1466 // RankOp
1467 //===----------------------------------------------------------------------===//
1469 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1470  auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1471  if (!shape)
1472  return {};
1473  int64_t rank = shape.getNumElements();
1474  Builder builder(getContext());
1475  return builder.getIndexAttr(rank);
1476 }
1478 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1479 /// Constant folding fails in cases where only the rank is constant, not the
1480 /// shape itself.
1481 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1482 ///
1483 /// Example:
1484 ///
1485 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1486 /// %rank = shape.rank %shape
1487 ///
1488 /// becomes
1489 ///
1490 /// %rank = shape.const_size 3
1492 namespace {
1493 struct RankShapeOfCanonicalizationPattern
1494  : public OpRewritePattern<shape::RankOp> {
1497  LogicalResult matchAndRewrite(shape::RankOp op,
1498  PatternRewriter &rewriter) const override {
1499  auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1500  if (!shapeOfOp)
1501  return failure();
1502  auto rankedTensorType =
1503  llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1504  if (!rankedTensorType)
1505  return failure();
1506  int64_t rank = rankedTensorType.getRank();
1507  if (llvm::isa<IndexType>(op.getType())) {
1508  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1509  rank);
1510  } else if (llvm::isa<shape::SizeType>(op.getType())) {
1511  rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1512  } else {
1513  return failure();
1514  }
1515  return success();
1516  }
1517 };
1518 } // namespace
1520 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1521  MLIRContext *context) {
1522  patterns.add<RankShapeOfCanonicalizationPattern>(context);
1523 }
1525 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1526  MLIRContext *context, std::optional<Location> location,
1527  RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1528  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1529  inferredReturnTypes.assign({SizeType::get(context)});
1530  else
1531  inferredReturnTypes.assign({IndexType::get(context)});
1532  return success();
1533 }
1535 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1536  // SizeType is compatible with IndexType.
1537  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1538 }
1540 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1542 //===----------------------------------------------------------------------===//
1543 // NumElementsOp
1544 //===----------------------------------------------------------------------===//
1546 OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1548  // Fold only when argument constant.
1549  Attribute shape = adaptor.getShape();
1550  if (!shape)
1551  return {};
1553  APInt product(64, 1);
1554  for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
1555  product *= value;
1556  Builder builder(getContext());
1557  return builder.getIndexAttr(product.getLimitedValue());
1558 }
1560 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1561  MLIRContext *context, std::optional<Location> location,
1562  NumElementsOp::Adaptor adaptor,
1563  SmallVectorImpl<Type> &inferredReturnTypes) {
1564  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1565  inferredReturnTypes.assign({SizeType::get(context)});
1566  else
1567  inferredReturnTypes.assign({IndexType::get(context)});
1568  return success();
1569 }
1571 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1572  TypeRange r) {
1573  // SizeType is compatible with IndexType.
1574  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1575 }
1577 LogicalResult shape::NumElementsOp::verify() {
1578  return verifySizeOrIndexOp(*this);
1579 }
1581 //===----------------------------------------------------------------------===//
1582 // MaxOp
1583 //===----------------------------------------------------------------------===//
1585 OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
1586  // If operands are equal, just propagate one.
1587  if (getLhs() == getRhs())
1588  return getLhs();
1589  return nullptr;
1590 }
1592 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1593  MLIRContext *context, std::optional<Location> location,
1594  MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1595  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1596  inferredReturnTypes.assign({adaptor.getLhs().getType()});
1597  else
1598  inferredReturnTypes.assign({SizeType::get(context)});
1599  return success();
1600 }
1602 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1603  if (l.size() != 1 || r.size() != 1)
1604  return false;
1605  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1606  return true;
1607  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1608  return true;
1609  return false;
1610 }
1612 //===----------------------------------------------------------------------===//
1613 // MinOp
1614 //===----------------------------------------------------------------------===//
1616 OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
1617  // If operands are equal, just propagate one.
1618  if (getLhs() == getRhs())
1619  return getLhs();
1620  return nullptr;
1621 }
1623 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1624  MLIRContext *context, std::optional<Location> location,
1625  MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1626  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1627  inferredReturnTypes.assign({adaptor.getLhs().getType()});
1628  else
1629  inferredReturnTypes.assign({SizeType::get(context)});
1630  return success();
1631 }
1633 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1634  if (l.size() != 1 || r.size() != 1)
1635  return false;
1636  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1637  return true;
1638  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1639  return true;
1640  return false;
1641 }
1643 //===----------------------------------------------------------------------===//
1644 // MulOp
1645 //===----------------------------------------------------------------------===//
1647 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1648  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1649  if (!lhs)
1650  return nullptr;
1651  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1652  if (!rhs)
1653  return nullptr;
1654  APInt folded = lhs.getValue() * rhs.getValue();
1655  Type indexTy = IndexType::get(getContext());
1656  return IntegerAttr::get(indexTy, folded);
1657 }
1659 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1660  MLIRContext *context, std::optional<Location> location,
1661  MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1662  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1663  llvm::isa<SizeType>(adaptor.getRhs().getType()))
1664  inferredReturnTypes.assign({SizeType::get(context)});
1665  else
1666  inferredReturnTypes.assign({IndexType::get(context)});
1667  return success();
1668 }
1670 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1671  // SizeType is compatible with IndexType.
1672  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1673 }
1675 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1677 //===----------------------------------------------------------------------===//
1678 // ShapeOfOp
1679 //===----------------------------------------------------------------------===//
1681 namespace {
1682 /// Replace shape_of(x) where x has a constant shape with a const_shape op.
1683 struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
1686  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1687  PatternRewriter &rewriter) const override {
1688  auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1689  if (!type || !type.hasStaticShape())
1690  return failure();
1691  Location loc = op.getLoc();
1692  Value constShape =
1693  rewriter
1694  .create<ConstShapeOp>(loc,
1695  rewriter.getIndexTensorAttr(type.getShape()))
1696  .getResult();
1697  if (constShape.getType() != op.getResult().getType())
1698  constShape = rewriter.create<tensor::CastOp>(
1699  loc, op.getResult().getType(), constShape);
1700  rewriter.replaceOp(op, constShape);
1701  return success();
1702  }
1703 };
1705 // Canonicalize
1706 //
1707 // %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1708 // %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
1709 //
1710 // to
1711 //
1712 // %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1713 // %1 = %shape
1714 //
1715 struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
1718  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1719  PatternRewriter &rewriter) const override {
1720  auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1721  if (!tensorReshapeOp)
1722  return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
1723  if (!isa<TensorType>(op.getType()))
1724  return rewriter.notifyMatchFailure(op, "result is not a tensor");
1726  // Operand 'shape' of 'tensor.reshape' may now be used as the result of
1727  // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
1728  // formed IR, it may not be identical (dynamically vs statically shaped),
1729  // in which case it needs to be cast first.
1730  Value shape = tensorReshapeOp.getShape();
1731  if (op.getType() != shape.getType())
1732  shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
1734  rewriter.replaceOp(op, shape);
1735  return success();
1736  }
1737 };
1739 // Canonicalize
1740 // ```
1741 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1742 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1743 // ```
1744 // to
1745 // ```
1746 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1747 // ```
1748 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1751  LogicalResult matchAndRewrite(tensor::CastOp op,
1752  PatternRewriter &rewriter) const override {
1753  auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1754  if (!ty || ty.getRank() != 1)
1755  return failure();
1757  auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1758  if (!shapeOfOp)
1759  return failure();
1761  // Argument type must be ranked and must not conflict.
1762  auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1763  if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1764  return failure();
1766  rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1767  return success();
1768  }
1769 };
1770 } // namespace
1772 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1773  MLIRContext *context) {
1774  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1775  ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1776  context);
1777 }
1779 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1780  MLIRContext *context, std::optional<Location> location,
1781  ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1782  if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1783  inferredReturnTypes.assign({ShapeType::get(context)});
1784  else {
1785  auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1786  int64_t rank =
1787  shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1788  Type indexTy = IndexType::get(context);
1789  Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1790  inferredReturnTypes.assign({extentTensorTy});
1791  }
1792  return success();
1793 }
1795 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1796  if (l.size() != 1 || r.size() != 1)
1797  return false;
1798  if (l == r)
1799  return true;
1801  Type lhs = l.front();
1802  Type rhs = r.front();
1804  if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1805  !llvm::isa<ShapeType, ShapedType>(rhs))
1806  return false;
1808  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1809  // Shape type is compatible with all other valid return types.
1810  return true;
1812  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1813  return true;
1814  return false;
1815 }
1817 LogicalResult shape::ShapeOfOp::verify() {
1818  return verifyShapeOrExtentTensorOp(*this);
1819 }
1821 //===----------------------------------------------------------------------===//
1822 // SizeToIndexOp
1823 //===----------------------------------------------------------------------===//
1825 OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1826  // Constant values of both types, `shape.size` and `index`, are represented as
1827  // `IntegerAttr`s which makes constant folding simple.
1828  if (Attribute arg = adaptor.getArg())
1829  return arg;
1830  return OpFoldResult();
1831 }
1833 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1834  MLIRContext *context) {
1835  patterns.add<IndexToSizeToIndexCanonicalization>(context);
1836 }
1838 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1839  if (inputs.size() != 1 || outputs.size() != 1)
1840  return false;
1841  return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1842  llvm::isa<IndexType>(outputs[0]);
1843 }
1845 //===----------------------------------------------------------------------===//
1846 // YieldOp
1847 //===----------------------------------------------------------------------===//
1849 LogicalResult shape::YieldOp::verify() {
1850  auto *parentOp = (*this)->getParentOp();
1851  auto results = parentOp->getResults();
1852  auto operands = getOperands();
1854  if (parentOp->getNumResults() != getNumOperands())
1855  return emitOpError() << "number of operands does not match number of "
1856  "results of its parent";
1857  for (auto e : llvm::zip(results, operands))
1858  if (std::get<0>(e).getType() != std::get<1>(e).getType())
1859  return emitOpError() << "types mismatch between yield op and its parent";
1861  return success();
1862 }
1864 //===----------------------------------------------------------------------===//
1865 // SplitAtOp
1866 //===----------------------------------------------------------------------===//
1868 LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1869  SmallVectorImpl<OpFoldResult> &results) {
1870  if (!adaptor.getOperand() || !adaptor.getIndex())
1871  return failure();
1872  auto shapeVec = llvm::to_vector<6>(
1873  llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1874  auto shape = llvm::ArrayRef(shapeVec);
1875  auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1876  // Verify that the split point is in the correct range.
1877  // TODO: Constant fold to an "error".
1878  int64_t rank = shape.size();
1879  if (-rank > splitPoint || splitPoint > rank)
1880  return failure();
1881  if (splitPoint < 0)
1882  splitPoint += shape.size();
1883  Builder builder(adaptor.getOperand().getContext());
1884  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1885  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1886  return success();
1887 }
1889 //===----------------------------------------------------------------------===//
1890 // ToExtentTensorOp
1891 //===----------------------------------------------------------------------===//
1893 OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1894  if (!adaptor.getInput())
1895  return OpFoldResult();
1896  Builder builder(getContext());
1897  auto shape = llvm::to_vector<6>(
1898  llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1899  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1900  builder.getIndexType());
1901  return DenseIntElementsAttr::get(type, shape);
1902 }
1904 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1905  if (inputs.size() != 1 || outputs.size() != 1)
1906  return false;
1907  if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1908  if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1909  inputTensor.getRank() != 1)
1910  return false;
1911  } else if (!llvm::isa<ShapeType>(inputs[0])) {
1912  return false;
1913  }
1915  TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1916  return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1917 }
1919 //===----------------------------------------------------------------------===//
1920 // ReduceOp
1921 //===----------------------------------------------------------------------===//
1923 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1924  ValueRange initVals) {
1925  OpBuilder::InsertionGuard g(builder);
1926  result.addOperands(shape);
1927  result.addOperands(initVals);
1929  Region *bodyRegion = result.addRegion();
1930  Block *bodyBlock = builder.createBlock(
1931  bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);
1933  Type elementType;
1934  if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
1935  elementType = tensorType.getElementType();
1936  else
1937  elementType = SizeType::get(builder.getContext());
1938  bodyBlock->addArgument(elementType, shape.getLoc());
1940  for (Value initVal : initVals) {
1941  bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
1942  result.addTypes(initVal.getType());
1943  }
1944 }
1946 LogicalResult ReduceOp::verify() {
1947  // Verify block arg types.
1948  Block &block = getRegion().front();
1950  // The block takes index, extent, and aggregated values as arguments.
1951  auto blockArgsCount = getInitVals().size() + 2;
1952  if (block.getNumArguments() != blockArgsCount)
1953  return emitOpError() << "ReduceOp body is expected to have "
1954  << blockArgsCount << " arguments";
1956  // The first block argument is the index and must always be of type `index`.
1957  if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
1958  return emitOpError(
1959  "argument 0 of ReduceOp body is expected to be of IndexType");
1961  // The second block argument is the extent and must be of type `size` or
1962  // `index`, depending on whether the reduce operation is applied to a shape or
1963  // to an extent tensor.
1964  Type extentTy = block.getArgument(1).getType();
1965  if (llvm::isa<ShapeType>(getShape().getType())) {
1966  if (!llvm::isa<SizeType>(extentTy))
1967  return emitOpError("argument 1 of ReduceOp body is expected to be of "
1968  "SizeType if the ReduceOp operates on a ShapeType");
1969  } else {
1970  if (!llvm::isa<IndexType>(extentTy))
1971  return emitOpError(
1972  "argument 1 of ReduceOp body is expected to be of IndexType if the "
1973  "ReduceOp operates on an extent tensor");
1974  }
1976  for (const auto &type : llvm::enumerate(getInitVals()))
1977  if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1978  return emitOpError() << "type mismatch between argument "
1979  << type.index() + 2
1980  << " of ReduceOp body and initial value "
1981  << type.index();
1982  return success();
1983 }
1985 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1986  // Parse operands.
1988  Type shapeOrExtentTensorType;
1989  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1991  parser.parseColonType(shapeOrExtentTensorType) ||
1992  parser.parseOptionalArrowTypeList(result.types))
1993  return failure();
1995  // Resolve operands.
1996  auto initVals = llvm::ArrayRef(operands).drop_front();
1997  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1998  result.operands) ||
1999  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2000  result.operands))
2001  return failure();
2003  // Parse the body.
2004  Region *body = result.addRegion();
2005  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
2006  return failure();
2008  // Parse attributes.
2009  if (parser.parseOptionalAttrDict(result.attributes))
2010  return failure();
2012  return success();
2013 }
2015 void ReduceOp::print(OpAsmPrinter &p) {
2016  p << '(' << getShape() << ", " << getInitVals()
2017  << ") : " << getShape().getType();
2018  p.printOptionalArrowTypeList(getResultTypes());
2019  p << ' ';
2020  p.printRegion(getRegion());
2021  p.printOptionalAttrDict((*this)->getAttrs());
2022 }
2024 #define GET_OP_CLASSES
2025 #include "mlir/Dialect/Shape/IR/"
2028 #include "mlir/Dialect/Shape/IR/"
