// MLIR 18.0.0git — Shape.cpp (source listing).
//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
#include <utility>

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
30 
31 using namespace mlir;
32 using namespace mlir::shape;
33 
34 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
35 
36 namespace {
37 #include "ShapeCanonicalization.inc"
38 } // namespace
39 
40 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
41  return RankedTensorType::get({rank}, IndexType::get(ctx));
42 }
43 
45  auto ranked = llvm::dyn_cast<RankedTensorType>(type);
46  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
47 }
48 
50  SmallVectorImpl<int64_t> &shapeValues) {
51  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
52  auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
53  if (!type.hasRank())
54  return failure();
55  llvm::append_range(shapeValues, type.getShape());
56  return success();
57  }
59  if (matchPattern(input, m_Constant(&attr))) {
60  llvm::append_range(shapeValues, attr.getValues<int64_t>());
61  return success();
62  }
63  return failure();
64 }
65 
66 static bool isErrorPropagationPossible(TypeRange operandTypes) {
67  return llvm::any_of(operandTypes, [](Type ty) {
68  return llvm::isa<SizeType, ShapeType, ValueShapeType>(ty);
69  });
70 }
71 
73  assert(op != nullptr && op->getNumResults() == 1);
74  Type resultTy = op->getResultTypes().front();
76  if (!llvm::isa<SizeType>(resultTy))
77  return op->emitOpError()
78  << "if at least one of the operands can hold error values then "
79  "the result must be of type `size` to propagate them";
80  }
81  return success();
82 }
83 
85  assert(op != nullptr && op->getNumResults() == 1);
86  Type resultTy = op->getResultTypes().front();
88  if (!llvm::isa<ShapeType>(resultTy))
89  return op->emitOpError()
90  << "if at least one of the operands can hold error values then "
91  "the result must be of type `shape` to propagate them";
92  }
93  return success();
94 }
95 
96 template <typename... Ty>
97 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
98  return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
99 }
100 
101 template <typename... Ty, typename... ranges>
102 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
103  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
104 }
105 
//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//
109 
110 namespace {
111 /// This class defines the interface for inlining shape dialect ops.
112 struct ShapeInlinerInterface : public DialectInlinerInterface {
114 
115  // Returns true if the given region 'src' can be inlined into the region
116  // 'dest' that is attached to an operation registered to the current dialect.
117  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
118  IRMapping &) const final {
119  return true;
120  }
121 
122  // Returns true if the given operation 'op', that is registered to this
123  // dialect, can be inlined into the region 'dest' that is attached to an
124  // operation registered to the current dialect.
125  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
126  IRMapping &) const final {
127  return true;
128  }
129 };
130 } // namespace
131 
132 void ShapeDialect::initialize() {
133  addOperations<
134 #define GET_OP_LIST
135 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
136  >();
137  addTypes<
138 #define GET_TYPEDEF_LIST
139 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
140  >();
141  addInterfaces<ShapeInlinerInterface>();
142  // Allow unknown operations during prototyping and testing. As the dialect is
143  // still evolving it makes it simple to start with an unregistered ops and
144  // try different variants before actually defining the op.
145  allowUnknownOperations();
146 }
147 
149  Attribute value, Type type,
150  Location loc) {
151  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
152  return builder.create<ub::PoisonOp>(loc, type, poison);
153 
154  if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
155  return builder.create<ConstShapeOp>(
156  loc, type, llvm::cast<DenseIntElementsAttr>(value));
157  if (llvm::isa<SizeType>(type))
158  return builder.create<ConstSizeOp>(loc, type,
159  llvm::cast<IntegerAttr>(value));
160  if (llvm::isa<WitnessType>(type))
161  return builder.create<ConstWitnessOp>(loc, type,
162  llvm::cast<BoolAttr>(value));
163 
164  return arith::ConstantOp::materialize(builder, value, type, loc);
165 }
166 
167 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
168  NamedAttribute attribute) {
169  // Verify shape.lib attribute.
170  if (attribute.getName() == "shape.lib") {
171  if (!op->hasTrait<OpTrait::SymbolTable>())
172  return op->emitError(
173  "shape.lib attribute may only be on op implementing SymbolTable");
174 
175  if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
176  auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
177  if (!symbol)
178  return op->emitError("shape function library ")
179  << symbolRef << " not found";
180  return isa<shape::FunctionLibraryOp>(symbol)
181  ? success()
182  : op->emitError()
183  << symbolRef << " required to be shape function library";
184  }
185 
186  if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
187  // Verify all entries are function libraries and mappings in libraries
188  // refer to unique ops.
190  for (auto it : arr) {
191  if (!llvm::isa<SymbolRefAttr>(it))
192  return op->emitError(
193  "only SymbolRefAttr allowed in shape.lib attribute array");
194 
195  auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
196  SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
197  if (!shapeFnLib)
198  return op->emitError()
199  << it << " does not refer to FunctionLibraryOp";
200  for (auto mapping : shapeFnLib.getMapping()) {
201  if (!key.insert(mapping.getName()).second) {
202  return op->emitError("only one op to shape mapping allowed, found "
203  "multiple for `")
204  << mapping.getName() << "`";
205  }
206  }
207  }
208  return success();
209  }
210 
211  return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
212  "allowed as shape.lib attribute");
213  }
214  return success();
215 }
216 
//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//
220 
221 // TODO: Canonicalization should be implemented for shapes that can be
222 // determined through mixtures of the known dimensions of the inputs.
223 OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
224  // Only the last operand is checked because AnyOp is commutative.
225  if (adaptor.getInputs().back())
226  return adaptor.getInputs().back();
227 
228  return nullptr;
229 }
230 
//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//
234 
235 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
236  result.regions.reserve(1);
237  Region *doRegion = result.addRegion();
238 
239  auto &builder = parser.getBuilder();
241  if (parser.parseOperand(cond) ||
242  parser.resolveOperand(cond, builder.getType<WitnessType>(),
243  result.operands))
244  return failure();
245 
246  // Parse optional results type list.
247  if (parser.parseOptionalArrowTypeList(result.types))
248  return failure();
249 
250  // Parse the region and add a terminator if elided.
251  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
252  return failure();
253  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
254 
255  // Parse the optional attribute list.
256  if (parser.parseOptionalAttrDict(result.attributes))
257  return failure();
258  return success();
259 }
260 
262  bool yieldsResults = !getResults().empty();
263 
264  p << " " << getWitness();
265  if (yieldsResults)
266  p << " -> (" << getResultTypes() << ")";
267  p << ' ';
268  p.printRegion(getDoRegion(),
269  /*printEntryBlockArgs=*/false,
270  /*printBlockTerminators=*/yieldsResults);
271  p.printOptionalAttrDict((*this)->getAttrs());
272 }
273 
274 namespace {
275 // Removes AssumingOp with a passing witness and inlines the region.
276 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
278 
279  LogicalResult matchAndRewrite(AssumingOp op,
280  PatternRewriter &rewriter) const override {
281  auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
282  if (!witness || !witness.getPassingAttr())
283  return failure();
284 
285  AssumingOp::inlineRegionIntoParent(op, rewriter);
286  return success();
287  }
288 };
289 
290 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
292 
293  LogicalResult matchAndRewrite(AssumingOp op,
294  PatternRewriter &rewriter) const override {
295  Block *body = op.getBody();
296  auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
297 
298  // Find used values.
299  SmallVector<Value, 4> newYieldOperands;
300  for (auto [opResult, yieldOperand] :
301  llvm::zip(op.getResults(), yieldOp.getOperands())) {
302  if (!opResult.getUses().empty()) {
303  newYieldOperands.push_back(yieldOperand);
304  }
305  }
306 
307  // Rewrite only if redundant results exist.
308  if (newYieldOperands.size() == yieldOp->getNumOperands())
309  return failure();
310 
311  // Replace yield op in the old assuming op's body and move the entire region
312  // to the new assuming op.
313  rewriter.setInsertionPointToEnd(body);
314  auto newYieldOp =
315  rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
316  rewriter.setInsertionPoint(op);
317  auto newOp = rewriter.create<AssumingOp>(
318  op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
319  newOp.getDoRegion().takeBody(op.getDoRegion());
320 
321  // Use the new results to replace the previously used ones.
322  SmallVector<Value, 4> replacementValues;
323  auto src = newOp.getResults().begin();
324  for (auto it : op.getResults()) {
325  if (it.getUses().empty())
326  replacementValues.push_back(nullptr);
327  else
328  replacementValues.push_back(*src++);
329  }
330  rewriter.replaceOp(op, replacementValues);
331  return success();
332  }
333 };
334 } // namespace
335 
336 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
337  MLIRContext *context) {
338  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
339 }
340 
341 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
342 void AssumingOp::getSuccessorRegions(
344  // AssumingOp has unconditional control flow into the region and back to the
345  // parent, so return the correct RegionSuccessor purely based on the index
346  // being None or 0.
347  if (!point.isParent()) {
348  regions.push_back(RegionSuccessor(getResults()));
349  return;
350  }
351 
352  regions.push_back(RegionSuccessor(&getDoRegion()));
353 }
354 
355 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
356  PatternRewriter &rewriter) {
357  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
358  auto *assumingBlock = op.getBody();
359  auto initPosition = rewriter.getInsertionPoint();
360  auto *blockAfterAssuming =
361  rewriter.splitBlock(blockBeforeAssuming, initPosition);
362 
363  // Remove the AssumingOp and AssumingYieldOp.
364  auto &yieldOp = assumingBlock->back();
365  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
366  rewriter.replaceOp(op, yieldOp.getOperands());
367  rewriter.eraseOp(&yieldOp);
368 
369  // Merge blocks together as there was no branching behavior from the
370  // AssumingOp.
371  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
372  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
373 }
374 
375 void AssumingOp::build(
376  OpBuilder &builder, OperationState &result, Value witness,
378 
379  result.addOperands(witness);
380  Region *bodyRegion = result.addRegion();
381  bodyRegion->push_back(new Block);
382  Block &bodyBlock = bodyRegion->front();
383 
384  // Build body.
385  OpBuilder::InsertionGuard guard(builder);
386  builder.setInsertionPointToStart(&bodyBlock);
387  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
388  builder.create<AssumingYieldOp>(result.location, yieldValues);
389 
390  SmallVector<Type, 2> assumingTypes;
391  for (Value v : yieldValues)
392  assumingTypes.push_back(v.getType());
393  result.addTypes(assumingTypes);
394 }
395 
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
399 
400 LogicalResult mlir::shape::AddOp::inferReturnTypes(
401  MLIRContext *context, std::optional<Location> location,
402  AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
403  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
404  llvm::isa<SizeType>(adaptor.getRhs().getType()))
405  inferredReturnTypes.assign({SizeType::get(context)});
406  else
407  inferredReturnTypes.assign({IndexType::get(context)});
408  return success();
409 }
410 
411 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
412  // SizeType is compatible with IndexType.
413  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
414 }
415 
416 OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
417  // add(x, 0) -> x
418  if (matchPattern(getRhs(), m_Zero()))
419  return getLhs();
420 
421  return constFoldBinaryOp<IntegerAttr>(
422  adaptor.getOperands(),
423  [](APInt a, const APInt &b) { return std::move(a) + b; });
424 }
425 
427 

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//
431 
432 namespace {
433 
434 // Merge multiple `shape.assuming_all` operations together.
435 //
436 // %0 = shape.assuming_all %w0, %w1
437 // %1 = shape.assuming_all %w2, %0
438 //
439 // to:
440 //
441 // %0 = shape.assuming_all %w0, %w2, %w2
442 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
444 
445  LogicalResult matchAndRewrite(AssumingAllOp op,
446  PatternRewriter &rewriter) const override {
447  SmallVector<Value> operands;
448 
449  for (Value operand : op.getInputs()) {
450  if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
451  operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
452  else
453  operands.push_back(operand);
454  }
455 
456  // We didn't find any other `assuming_all` ops to merge with.
457  if (operands.size() == op.getNumOperands())
458  return failure();
459 
460  // Replace with a new `assuming_all` operation with merged constraints.
461  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
462  return success();
463  }
464 };
465 
466 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
467 // are subsumed by others.
468 //
469 // %0 = shape.cstr_broadcastable %shape0, %shape1
470 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
471 //
472 // %2 = shape.cstr_broadcastable %shape3, %shape4
473 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
474 //
475 // %4 = shape.assuming_all %0, %1, %2, %3
476 //
477 // to:
478 //
479 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
480 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
481 // %2 = shape.assuming_all %0, %1
482 //
483 // In this example if shapes [0, 1, 2] are broadcastable, then it means that
484 // shapes [0, 1] are broadcastable too, and can be removed from the list of
485 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
486 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
487 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
489 
490  LogicalResult matchAndRewrite(AssumingAllOp op,
491  PatternRewriter &rewriter) const override {
492  // Collect all `CstrBroadcastableOp` operands first.
494  for (Value operand : op.getInputs()) {
495  // TODO: Apply this optimization if some of the witnesses are not
496  // produced by the `cstr_broadcastable`.
497  auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
498  if (!broadcastable)
499  return failure();
500 
501  operands.insert(broadcastable);
502  }
503 
504  // Skip trivial `assuming_all` operations.
505  if (operands.size() <= 1)
506  return failure();
507 
508  // Collect shapes checked by `cstr_broadcastable` operands.
510  for (auto cstr : operands) {
511  DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
512  shapes.emplace_back(cstr, std::move(shapesSet));
513  }
514 
515  // Sort by the number of shape operands (larger to smaller).
516  llvm::sort(shapes, [](auto a, auto b) {
517  return a.first.getNumOperands() > b.first.getNumOperands();
518  });
519 
520  // We start from the `cst_broadcastable` operations with largest number of
521  // shape operands, and remove redundant `cst_broadcastable` operations. We
522  // do this until we find a set of `cst_broadcastable` operations with
523  // non-overlapping constraints.
524  SmallVector<CstrBroadcastableOp> markedForErase;
525 
526  for (unsigned i = 0; i < shapes.size(); ++i) {
527  auto isSubset = [&](auto pair) {
528  return llvm::set_is_subset(pair.second, shapes[i].second);
529  };
530 
531  // Keep redundant `cstr_broadcastable` operations to be erased.
532  auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
533  for (auto *it0 = it; it0 < shapes.end(); ++it0)
534  markedForErase.push_back(it0->first);
535  shapes.erase(it, shapes.end());
536  }
537 
538  // We didn't find any operands that could be removed.
539  if (markedForErase.empty())
540  return failure();
541 
542  // Collect non-overlapping `cst_broadcastable` constraints.
543  SmallVector<Value> uniqueConstraints;
544  for (auto &shape : shapes)
545  uniqueConstraints.push_back(shape.first.getResult());
546 
547  // Replace with a new `assuming_all` operation ...
548  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
549 
550  // ... and maybe erase `cstr_broadcastable` ops without uses.
551  for (auto &op : markedForErase)
552  if (op->use_empty())
553  rewriter.eraseOp(op);
554 
555  return success();
556  }
557 };
558 
559 struct AssumingAllToCstrEqCanonicalization
560  : public OpRewritePattern<AssumingAllOp> {
562 
563  LogicalResult matchAndRewrite(AssumingAllOp op,
564  PatternRewriter &rewriter) const override {
565  SmallVector<Value, 8> shapes;
566  for (Value w : op.getInputs()) {
567  auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
568  if (!cstrEqOp)
569  return failure();
570  bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
571  return llvm::is_contained(shapes, s);
572  });
573  if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
574  return failure();
575  shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
576  }
577  rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
578  return success();
579  }
580 };
581 
582 template <typename OpTy>
583 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
585 
586  LogicalResult matchAndRewrite(OpTy op,
587  PatternRewriter &rewriter) const override {
588  // Find unique operands.
589  SetVector<Value> unique(op.operand_begin(), op.operand_end());
590 
591  // Reduce op to equivalent with unique operands.
592  if (unique.size() < op.getNumOperands()) {
593  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
594  unique.takeVector(), op->getAttrs());
595  return success();
596  }
597 
598  return failure();
599  }
600 };
601 } // namespace
602 
603 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
604  MLIRContext *context) {
605  patterns
606  .add<MergeAssumingAllOps, AssumingAllOneOp,
607  AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
608  RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
609 }
610 
611 OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
612  // Iterate in reverse to first handle all constant operands. They are
613  // guaranteed to be the tail of the inputs because this is commutative.
614  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
615  Attribute a = adaptor.getInputs()[idx];
616  // Cannot fold if any inputs are not constant;
617  if (!a)
618  return nullptr;
619 
620  // We do not need to keep statically known values after handling them in
621  // this method.
622  getOperation()->eraseOperand(idx);
623 
624  // Always false if any input is statically known false
625  if (!llvm::cast<BoolAttr>(a).getValue())
626  return a;
627  }
628  // If this is reached, all inputs were statically known passing.
629  return BoolAttr::get(getContext(), true);
630 }
631 
633  // Ensure that AssumingAllOp contains at least one operand
634  if (getNumOperands() == 0)
635  return emitOpError("no operands specified");
636 
637  return success();
638 }
639 
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
643 
644 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
645  if (getShapes().size() == 1) {
646  // Otherwise, we need a cast which would be a canonicalization, not folding.
647  if (getShapes().front().getType() != getType())
648  return nullptr;
649  return getShapes().front();
650  }
651 
652  // TODO: Support folding with more than 2 input shapes
653  if (getShapes().size() > 2)
654  return nullptr;
655 
656  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
657  return nullptr;
658  auto lhsShape = llvm::to_vector<6>(
659  llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
660  .getValues<int64_t>());
661  auto rhsShape = llvm::to_vector<6>(
662  llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
663  .getValues<int64_t>());
664  SmallVector<int64_t, 6> resultShape;
665 
666  // If the shapes are not compatible, we can't fold it.
667  // TODO: Fold to an "error".
668  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
669  return nullptr;
670 
671  Builder builder(getContext());
672  return builder.getIndexTensorAttr(resultShape);
673 }
674 
676  return verifyShapeOrExtentTensorOp(*this);
677 }
678 
679 namespace {
680 template <typename OpTy>
681 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
683 
684  LogicalResult matchAndRewrite(OpTy op,
685  PatternRewriter &rewriter) const override {
686  auto isPotentiallyNonEmptyShape = [](Value shape) {
687  if (auto extentTensorTy =
688  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
689  if (extentTensorTy.getDimSize(0) == 0)
690  return false;
691  }
692  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
693  if (constShape.getShape().empty())
694  return false;
695  }
696  return true;
697  };
698  auto newOperands = llvm::to_vector<8>(
699  llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
700 
701  // Reduce op to equivalent without empty shape operands.
702  if (newOperands.size() < op.getNumOperands()) {
703  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
704  op->getAttrs());
705  return success();
706  }
707 
708  return failure();
709  }
710 };
711 
712 struct BroadcastForwardSingleOperandPattern
713  : public OpRewritePattern<BroadcastOp> {
715 
716  LogicalResult matchAndRewrite(BroadcastOp op,
717  PatternRewriter &rewriter) const override {
718  if (op.getNumOperands() != 1)
719  return failure();
720  Value replacement = op.getShapes().front();
721 
722  // Insert cast if needed.
723  if (replacement.getType() != op.getType()) {
724  auto loc = op.getLoc();
725  if (llvm::isa<ShapeType>(op.getType())) {
726  replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
727  } else {
728  assert(!llvm::isa<ShapeType>(op.getType()) &&
729  !llvm::isa<ShapeType>(replacement.getType()) &&
730  "expect extent tensor cast");
731  replacement =
732  rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
733  }
734  }
735 
736  rewriter.replaceOp(op, replacement);
737  return success();
738  }
739 };
740 
741 struct BroadcastFoldConstantOperandsPattern
742  : public OpRewritePattern<BroadcastOp> {
744 
745  LogicalResult matchAndRewrite(BroadcastOp op,
746  PatternRewriter &rewriter) const override {
747  SmallVector<int64_t, 8> foldedConstantShape;
748  SmallVector<Value, 8> newShapeOperands;
749  for (Value shape : op.getShapes()) {
750  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
751  SmallVector<int64_t, 8> newFoldedConstantShape;
753  foldedConstantShape,
754  llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
755  newFoldedConstantShape)) {
756  foldedConstantShape = newFoldedConstantShape;
757  continue;
758  }
759  }
760  newShapeOperands.push_back(shape);
761  }
762 
763  // Need at least two constant operands to fold anything.
764  if (op.getNumOperands() - newShapeOperands.size() < 2)
765  return failure();
766 
767  auto foldedConstantOperandsTy = RankedTensorType::get(
768  {static_cast<int64_t>(foldedConstantShape.size())},
769  rewriter.getIndexType());
770  newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
771  op.getLoc(), foldedConstantOperandsTy,
772  rewriter.getIndexTensorAttr(foldedConstantShape)));
773  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
774  newShapeOperands);
775  return success();
776  }
777 };
778 
779 template <typename OpTy>
780 struct CanonicalizeCastExtentTensorOperandsPattern
781  : public OpRewritePattern<OpTy> {
783 
784  LogicalResult matchAndRewrite(OpTy op,
785  PatternRewriter &rewriter) const override {
786  // Canonicalize operands.
787  bool anyChange = false;
788  auto canonicalizeOperand = [&](Value operand) -> Value {
789  if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
790  // Only eliminate the cast if it holds no shape information.
791  bool isInformationLoosingCast =
792  llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
793  if (isInformationLoosingCast) {
794  anyChange = true;
795  return castOp.getSource();
796  }
797  }
798  return operand;
799  };
800  auto newOperands = llvm::to_vector<8>(
801  llvm::map_range(op.getOperands(), canonicalizeOperand));
802 
803  // Rewrite op if any change required.
804  if (!anyChange)
805  return failure();
806  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
807  return success();
808  }
809 };
810 
811 struct BroadcastConcretizeResultTypePattern
812  : public OpRewritePattern<BroadcastOp> {
814 
815  LogicalResult matchAndRewrite(BroadcastOp op,
816  PatternRewriter &rewriter) const override {
817  // Only concretize dynamic extent tensor result types.
818  auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
819  if (!resultTy || !resultTy.isDynamicDim(0))
820  return failure();
821 
822  // Infer resulting shape rank if possible.
823  int64_t maxRank = 0;
824  for (Value shape : op.getShapes()) {
825  if (auto extentTensorTy =
826  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
827  // Cannot infer resulting shape rank if any operand is dynamically
828  // ranked.
829  if (extentTensorTy.isDynamicDim(0))
830  return failure();
831  maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
832  }
833  }
834 
835  auto newOp = rewriter.create<BroadcastOp>(
836  op.getLoc(), getExtentTensorType(getContext(), maxRank),
837  op.getShapes());
838  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
839  return success();
840  }
841 };
842 } // namespace
843 
844 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
845  MLIRContext *context) {
846  patterns.add<BroadcastConcretizeResultTypePattern,
847  BroadcastFoldConstantOperandsPattern,
848  BroadcastForwardSingleOperandPattern,
849  CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
850  RemoveDuplicateOperandsPattern<BroadcastOp>,
851  RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
852 }
853 
//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//
857 
858 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
859  if (!adaptor.getLhs() || !adaptor.getRhs())
860  return nullptr;
861  auto lhsShape = llvm::to_vector<6>(
862  llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
863  auto rhsShape = llvm::to_vector<6>(
864  llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
865  SmallVector<int64_t, 6> resultShape;
866  resultShape.append(lhsShape.begin(), lhsShape.end());
867  resultShape.append(rhsShape.begin(), rhsShape.end());
868  Builder builder(getContext());
869  return builder.getIndexTensorAttr(resultShape);
870 }
871 
//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//
875 
877  p << " ";
878  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
879  p << "[";
880  interleaveComma(getShape().getValues<int64_t>(), p);
881  p << "] : ";
882  p.printType(getType());
883 }
884 
885 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
886  if (parser.parseOptionalAttrDict(result.attributes))
887  return failure();
888  // We piggy-back on ArrayAttr parsing, though we don't internally store the
889  // shape as an ArrayAttr.
890  // TODO: Implement custom parser and maybe make syntax a bit more concise.
891  Attribute extentsRaw;
892  NamedAttrList dummy;
893  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
894  return failure();
895  auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
896  if (!extentsArray)
897  return failure();
899  for (Attribute extent : extentsArray) {
900  IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
901  if (!attr)
902  return failure();
903  ints.push_back(attr.getInt());
904  }
905  Builder &builder = parser.getBuilder();
906  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
907  Type resultTy;
908  if (parser.parseColonType(resultTy))
909  return failure();
910  result.types.push_back(resultTy);
911  return success();
912 }
913 
914 OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
915 
916 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
917  MLIRContext *context) {
918  patterns.add<TensorCastConstShape>(context);
919 }
920 
921 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
922  MLIRContext *context, std::optional<Location> location,
923  ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
924  Builder b(context);
925  const Properties prop = adaptor.getProperties();
926  inferredReturnTypes.assign({RankedTensorType::get(
927  {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
928  return success();
929 }
930 
931 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
932  TypeRange r) {
933  if (l.size() != 1 || r.size() != 1)
934  return false;
935 
936  Type lhs = l.front();
937  Type rhs = r.front();
938 
939  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
940  // Shape type is compatible with all other valid return types.
941  return true;
942  return lhs == rhs;
943 }
944 
945 //===----------------------------------------------------------------------===//
946 // CstrBroadcastableOp
947 //===----------------------------------------------------------------------===//
948 
// Registers canonicalizations for cstr_broadcastable: cast-operand cleanup,
// equal-operand simplification, and duplicate/empty operand removal.
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
959 
960 // Return true if there is exactly one attribute not representing a scalar
961 // broadcast.
963  bool nonScalarSeen = false;
964  for (Attribute a : attributes) {
965  if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
966  if (nonScalarSeen)
967  return false;
968  nonScalarSeen = true;
969  }
970  }
971  return true;
972 }
973 
974 OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
975  // No broadcasting is needed if all operands but one are scalar.
976  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
977  return BoolAttr::get(getContext(), true);
978 
979  if ([&] {
981  for (const auto &operand : adaptor.getShapes()) {
982  if (!operand)
983  return false;
984  extents.push_back(llvm::to_vector<6>(
985  llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
986  }
988  }())
989  return BoolAttr::get(getContext(), true);
990 
991  // Lastly, see if folding can be completed based on what constraints are known
992  // on the input shapes.
993  if ([&] {
995  for (auto shapeValue : getShapes()) {
996  extents.emplace_back();
997  if (failed(getShapeVec(shapeValue, extents.back())))
998  return false;
999  }
1001  }())
1002  return BoolAttr::get(getContext(), true);
1003 
1004  // Because a failing witness result here represents an eventual assertion
1005  // failure, we do not replace it with a constant witness.
1006  return nullptr;
1007 }
1008 
1010  // Ensure that CstrBroadcastableOp contains at least two operands
1011  if (getNumOperands() < 2)
1012  return emitOpError("required at least 2 input shapes");
1013  return success();
1014 }
1015 
1016 //===----------------------------------------------------------------------===//
1017 // CstrEqOp
1018 //===----------------------------------------------------------------------===//
1019 
// Registers the CstrEqEqOps pattern (generated in ShapeCanonicalization.inc).
void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}
1025 
1026 OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
1027  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
1028  return a && a == adaptor.getShapes().front();
1029  }))
1030  return BoolAttr::get(getContext(), true);
1031 
1032  // Because a failing witness result here represents an eventual assertion
1033  // failure, we do not try to replace it with a constant witness. Similarly, we
1034  // cannot if there are any non-const inputs.
1035  return nullptr;
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // ConstSizeOp
1040 //===----------------------------------------------------------------------===//
1041 
// Convenience builder: wraps the raw integer in an index attribute and
// forwards to the attribute-taking builder.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
1046 
// Folds to the op's constant `value` attribute.
OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
1048 
1049 void ConstSizeOp::getAsmResultNames(
1050  llvm::function_ref<void(Value, StringRef)> setNameFn) {
1051  SmallString<4> buffer;
1052  llvm::raw_svector_ostream os(buffer);
1053  os << "c" << getValue();
1054  setNameFn(getResult(), os.str());
1055 }
1056 
1057 //===----------------------------------------------------------------------===//
1058 // ConstWitnessOp
1059 //===----------------------------------------------------------------------===//
1060 
// Folds to the op's constant `passing` attribute.
OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
1062 
1063 //===----------------------------------------------------------------------===//
1064 // CstrRequireOp
1065 //===----------------------------------------------------------------------===//
1066 
// Folds to the predicate operand's constant value when it is known (null,
// i.e. no fold, otherwise).
OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
  return adaptor.getPred();
}
1070 
1071 //===----------------------------------------------------------------------===//
1072 // DimOp
1073 //===----------------------------------------------------------------------===//
1074 
1075 std::optional<int64_t> DimOp::getConstantIndex() {
1076  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1077  return constSizeOp.getValue().getLimitedValue();
1078  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1079  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1080  return std::nullopt;
1081 }
1082 
1083 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1084  Type valType = getValue().getType();
1085  auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1086  if (!valShapedType || !valShapedType.hasRank())
1087  return nullptr;
1088  std::optional<int64_t> index = getConstantIndex();
1089  if (!index.has_value())
1090  return nullptr;
1091  if (index.value() < 0 || index.value() >= valShapedType.getRank())
1092  return nullptr;
1093  auto extent = valShapedType.getDimSize(*index);
1094  if (ShapedType::isDynamic(extent))
1095  return nullptr;
1096  return IntegerAttr::get(IndexType::get(getContext()), extent);
1097 }
1098 
// The result type mirrors the index operand's type (`index` or `!shape.size`).
LogicalResult mlir::shape::DimOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({adaptor.getIndex().getType()});
  return success();
}
1105 
// SizeType and IndexType are interchangeable result types for this op.
bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1109 
1110 //===----------------------------------------------------------------------===//
1111 // DivOp
1112 //===----------------------------------------------------------------------===//
1113 
1114 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
1115  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1116  if (!lhs)
1117  return nullptr;
1118  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1119  if (!rhs)
1120  return nullptr;
1121 
1122  // Division in APInt does not follow floor(lhs, rhs) when the result is
1123  // negative. Rather, APInt rounds toward zero.
1124  APInt quotient, remainder;
1125  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1126  if (quotient.isNegative() && !remainder.isZero()) {
1127  quotient -= 1;
1128  }
1129 
1130  Type indexTy = IndexType::get(getContext());
1131  return IntegerAttr::get(indexTy, quotient);
1132 }
1133 
1134 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1135  MLIRContext *context, std::optional<Location> location,
1136  DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1137  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1138  llvm::isa<SizeType>(adaptor.getRhs().getType()))
1139  inferredReturnTypes.assign({SizeType::get(context)});
1140  else
1141  inferredReturnTypes.assign({IndexType::get(context)});
1142  return success();
1143 }
1144 
bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1149 
// Delegates to the shared size-or-index operand/result verifier.
LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1151 
1152 //===----------------------------------------------------------------------===//
1153 // ShapeEqOp
1154 //===----------------------------------------------------------------------===//
1155 
1156 OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
1157  bool allSame = true;
1158  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1159  return {};
1160  for (Attribute operand : adaptor.getShapes().drop_front()) {
1161  if (!operand)
1162  return {};
1163  allSame = allSame && operand == adaptor.getShapes().front();
1164  }
1165  return BoolAttr::get(getContext(), allSame);
1166 }
1167 
1168 //===----------------------------------------------------------------------===//
1169 // IndexToSizeOp
1170 //===----------------------------------------------------------------------===//
1171 
1172 OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1173  // Constant values of both types, `shape.size` and `index`, are represented as
1174  // `IntegerAttr`s which makes constant folding simple.
1175  if (Attribute arg = adaptor.getArg())
1176  return arg;
1177  return {};
1178 }
1179 
// Registers the size->index->size round-trip elimination pattern (generated
// in ShapeCanonicalization.inc).
void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
1184 
1185 //===----------------------------------------------------------------------===//
1186 // FromExtentsOp
1187 //===----------------------------------------------------------------------===//
1188 
1189 OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1190  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
1191  return nullptr;
1192  SmallVector<int64_t, 6> extents;
1193  for (auto attr : adaptor.getExtents())
1194  extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1195  Builder builder(getContext());
1196  return builder.getIndexTensorAttr(extents);
1197 }
1198 
1199 //===----------------------------------------------------------------------===//
1200 // FunctionLibraryOp
1201 //===----------------------------------------------------------------------===//
1202 
1203 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1204  StringRef name) {
1205  result.attributes.push_back(builder.getNamedAttr(
1207 }
1208 
1209 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1210  auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1211  getMapping().get(op->getName().getIdentifier()));
1212  if (!attr)
1213  return nullptr;
1214  return lookupSymbol<FuncOp>(attr);
1215 }
1216 
// Parses: `shape.function_library @name attributes {...} { region } mapping
// {...}` — symbol name, optional attr dict, body region, then the required
// `mapping` dictionary attribute.
ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  // The mapping dictionary is typed NoneType and stored under "mapping".
  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
1242 
1244  p << ' ';
1245  p.printSymbolName(getName());
1247  (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1248  p << ' ';
1249  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1250  /*printBlockTerminators=*/false);
1251  p << " mapping ";
1252  p.printAttributeWithoutType(getMappingAttr());
1253 }
1254 
1255 //===----------------------------------------------------------------------===//
1256 // FuncOp
1257 //===----------------------------------------------------------------------===//
1258 
// Creates a detached FuncOp (not inserted into any block); the caller takes
// ownership of the returned op.
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder builder(location->getContext());
  OperationState state(location, getOperationName());
  FuncOp::build(builder, state, name, type, attrs);
  return cast<FuncOp>(Operation::create(state));
}
1266 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1268  SmallVector<NamedAttribute, 8> attrRef(attrs);
1269  return create(location, name, type, llvm::ArrayRef(attrRef));
1270 }
1271 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1273  ArrayRef<DictionaryAttr> argAttrs) {
1274  FuncOp func = create(location, name, type, attrs);
1275  func.setAllArgAttrs(argAttrs);
1276  return func;
1277 }
1278 
1279 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
1280  FunctionType type, ArrayRef<NamedAttribute> attrs,
1281  ArrayRef<DictionaryAttr> argAttrs) {
1282  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
1283  builder.getStringAttr(name));
1284  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
1285  TypeAttr::get(type));
1286  state.attributes.append(attrs.begin(), attrs.end());
1287  state.addRegion();
1288 
1289  if (argAttrs.empty())
1290  return;
1291  assert(type.getNumInputs() == argAttrs.size());
1293  builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
1294  getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
1295 }
1296 
1297 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1298  auto buildFuncType =
1299  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
1301  std::string &) { return builder.getFunctionType(argTypes, results); };
1302 
1304  parser, result, /*allowVariadic=*/false,
1305  getFunctionTypeAttrName(result.name), buildFuncType,
1306  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1307 }
1308 
1309 void FuncOp::print(OpAsmPrinter &p) {
1311  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
1312  getArgAttrsAttrName(), getResAttrsAttrName());
1313 }
1314 
1315 //===----------------------------------------------------------------------===//
1316 // GetExtentOp
1317 //===----------------------------------------------------------------------===//
1318 
1319 std::optional<int64_t> GetExtentOp::getConstantDim() {
1320  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1321  return constSizeOp.getValue().getLimitedValue();
1322  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1323  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1324  return std::nullopt;
1325 }
1326 
1327 OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
1328  auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1329  if (!elements)
1330  return nullptr;
1331  std::optional<int64_t> dim = getConstantDim();
1332  if (!dim.has_value())
1333  return nullptr;
1334  if (dim.value() >= elements.getNumElements())
1335  return nullptr;
1336  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1337 }
1338 
1339 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1340  int64_t dim) {
1341  auto loc = result.location;
1342  auto dimAttr = builder.getIndexAttr(dim);
1343  if (llvm::isa<ShapeType>(shape.getType())) {
1344  Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1345  build(builder, result, builder.getType<SizeType>(), shape, dim);
1346  } else {
1347  Value dim =
1348  builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1349  build(builder, result, builder.getIndexType(), shape, dim);
1350  }
1351 }
1352 
// Always infers a plain `index` result.
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1359 
bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1365 
1367 
1368 //===----------------------------------------------------------------------===//
1369 // IsBroadcastableOp
1370 //===----------------------------------------------------------------------===//
1371 
// Registers duplicate-operand removal for is_broadcastable.
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
1376 
1377 OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1378  // Can always broadcast fewer than two shapes.
1379  if (adaptor.getShapes().size() < 2) {
1380  return BoolAttr::get(getContext(), true);
1381  }
1382 
1383  return nullptr;
1384 }
1385 
1386 //===----------------------------------------------------------------------===//
1387 // MeetOp
1388 //===----------------------------------------------------------------------===//
1389 
// Infers the meet of all operand types by folding them pairwise into an
// accumulator: sizes meet sizes/indexes, shapes meet shapes/extent tensors,
// and extent tensors of equal (or dynamic) length meet each other. Mixing the
// size family with the shape family is an error.
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getOperands().empty())
    return failure();

  // A "shape" here is either the !shape.shape type or a 1-D index tensor.
  auto isShapeType = [](Type arg) {
    if (llvm::isa<ShapeType>(arg))
      return true;
    return isExtentTensorType(arg);
  };

  ValueRange::type_range types = adaptor.getOperands().getTypes();
  Type acc = types.front();
  for (auto t : drop_begin(types)) {
    Type l = acc, r = t;
    // Normalize so that an error-carrying type (!shape.shape / !shape.size),
    // if present, is on the left.
    if (!llvm::isa<ShapeType, SizeType>(l))
      std::swap(l, r);

    // Handle sizes, propagate error type if present.
    if (llvm::isa<SizeType>(l)) {
      if (llvm::isa<SizeType, IndexType>(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<IndexType>(l)) {
      if (llvm::isa<IndexType>(r))
        acc = r;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<ShapeType>(l)) {
      // Handle shapes, propagate error type if present.
      if (isShapeType(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (isExtentTensorType(l)) {
      // Both sides are extent tensors: prefer the statically-ranked one, and
      // require equal static lengths when both are static.
      auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
      auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
      if (ShapedType::isDynamic(rank1))
        acc = l;
      else if (ShapedType::isDynamic(rank2))
        acc = r;
      else if (rank1 != rank2)
        return emitOptionalError(location, "unequal shape cardinality");
      else
        acc = l;
    }
  }
  inferredReturnTypes.assign({acc});
  return success();
}
1442 
// Two result types are compatible when identical, when related through the
// size family (SizeType/IndexType), through the shape family
// (ShapeType/TensorType), or when both are shaped types with compatible
// static shapes.
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  // Normalize so an error-carrying type, if present, is on the left.
  if (!llvm::isa<ShapeType, SizeType>(lhs))
    std::swap(lhs, rhs);

  if (llvm::isa<SizeType>(lhs))
    return llvm::isa<SizeType, IndexType>(rhs);
  if (llvm::isa<ShapeType>(lhs))
    return llvm::isa<ShapeType, TensorType>(rhs);

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
1464 
1465 //===----------------------------------------------------------------------===//
1466 // RankOp
1467 //===----------------------------------------------------------------------===//
1468 
1469 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1470  auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1471  if (!shape)
1472  return {};
1473  int64_t rank = shape.getNumElements();
1474  Builder builder(getContext());
1475  return builder.getIndexAttr(rank);
1476 }
1477 
1478 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1479 /// Constant folding fails in cases where only the rank is constant, not the
1480 /// shape itself.
1481 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1482 ///
1483 /// Example:
1484 ///
1485 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1486 /// %rank = shape.rank %shape
1487 ///
1488 /// becomes
1489 ///
1490 /// %rank = shape.const_size 3
1491 
1492 namespace {
1493 struct RankShapeOfCanonicalizationPattern
1494  : public OpRewritePattern<shape::RankOp> {
1496 
1497  LogicalResult matchAndRewrite(shape::RankOp op,
1498  PatternRewriter &rewriter) const override {
1499  auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1500  if (!shapeOfOp)
1501  return failure();
1502  auto rankedTensorType =
1503  llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1504  if (!rankedTensorType)
1505  return failure();
1506  int64_t rank = rankedTensorType.getRank();
1507  if (llvm::isa<IndexType>(op.getType())) {
1508  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1509  rank);
1510  } else if (llvm::isa<shape::SizeType>(op.getType())) {
1511  rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1512  } else {
1513  return failure();
1514  }
1515  return success();
1516  }
1517 };
1518 } // namespace
1519 
// Registers the rank(shape_of(ranked)) -> constant pattern defined above.
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
1524 
1525 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1526  MLIRContext *context, std::optional<Location> location,
1527  RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1528  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1529  inferredReturnTypes.assign({SizeType::get(context)});
1530  else
1531  inferredReturnTypes.assign({IndexType::get(context)});
1532  return success();
1533 }
1534 
bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1539 
1541 
1542 //===----------------------------------------------------------------------===//
1543 // NumElementsOp
1544 //===----------------------------------------------------------------------===//
1545 
1546 OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1547 
1548  // Fold only when argument constant.
1549  Attribute shape = adaptor.getShape();
1550  if (!shape)
1551  return {};
1552 
1553  APInt product(64, 1);
1554  for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
1555  product *= value;
1556  Builder builder(getContext());
1557  return builder.getIndexAttr(product.getLimitedValue());
1558 }
1559 
// A !shape.shape operand yields !shape.size; extent tensors yield index.
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    NumElementsOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1570 
bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1576 
1578  return verifySizeOrIndexOp(*this);
1579 }
1580 
1581 //===----------------------------------------------------------------------===//
1582 // MaxOp
1583 //===----------------------------------------------------------------------===//
1584 
1585 OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
1586  // If operands are equal, just propagate one.
1587  if (getLhs() == getRhs())
1588  return getLhs();
1589  return nullptr;
1590 }
1591 
// Identical operand types propagate; mixed operand types fall back to
// !shape.size.
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
    inferredReturnTypes.assign({adaptor.getLhs().getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}
1601 
1602 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1603  if (l.size() != 1 || r.size() != 1)
1604  return false;
1605  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1606  return true;
1607  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1608  return true;
1609  return false;
1610 }
1611 
1612 //===----------------------------------------------------------------------===//
1613 // MinOp
1614 //===----------------------------------------------------------------------===//
1615 
1616 OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
1617  // If operands are equal, just propagate one.
1618  if (getLhs() == getRhs())
1619  return getLhs();
1620  return nullptr;
1621 }
1622 
// Identical operand types propagate; mixed operand types fall back to
// !shape.size.
LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
    inferredReturnTypes.assign({adaptor.getLhs().getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}
1632 
1633 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1634  if (l.size() != 1 || r.size() != 1)
1635  return false;
1636  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1637  return true;
1638  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1639  return true;
1640  return false;
1641 }
1642 
1643 //===----------------------------------------------------------------------===//
1644 // MulOp
1645 //===----------------------------------------------------------------------===//
1646 
1647 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1648  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1649  if (!lhs)
1650  return nullptr;
1651  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1652  if (!rhs)
1653  return nullptr;
1654  APInt folded = lhs.getValue() * rhs.getValue();
1655  Type indexTy = IndexType::get(getContext());
1656  return IntegerAttr::get(indexTy, folded);
1657 }
1658 
// If either operand is a !shape.size, so is the result; otherwise index.
LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1669 
bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1674 
1676 
1677 //===----------------------------------------------------------------------===//
1678 // ShapeOfOp
1679 //===----------------------------------------------------------------------===//
1680 
1681 OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
1682  auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
1683  if (!type || !type.hasStaticShape())
1684  return nullptr;
1685  Builder builder(getContext());
1686  return builder.getIndexTensorAttr(type.getShape());
1687 }
1688 
1689 namespace {
1690 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
1692 
1693  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1694  PatternRewriter &rewriter) const override {
1695  if (!llvm::isa<ShapedType>(op.getArg().getType()))
1696  return failure();
1697  if (llvm::isa<ShapedType>(op.getType()))
1698  return failure();
1699 
1700  rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
1701  op.getArg());
1702  return success();
1703  }
1704 };
1705 
1706 // Canonicalize
1707 // ```
1708 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1709 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1710 // ```
1711 // to
1712 // ```
1713 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1714 // ```
1715 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1717 
1718  LogicalResult matchAndRewrite(tensor::CastOp op,
1719  PatternRewriter &rewriter) const override {
1720  auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1721  if (!ty || ty.getRank() != 1)
1722  return failure();
1723 
1724  auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1725  if (!shapeOfOp)
1726  return failure();
1727 
1728  // Argument type must be ranked and must not conflict.
1729  auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1730  if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1731  return failure();
1732 
1733  rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1734  return success();
1735  }
1736 };
1737 } // namespace
1738 
// Registers the two local patterns above plus the generated
// ExtractFromShapeOfExtentTensor pattern (ShapeCanonicalization.inc).
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
               ExtractFromShapeOfExtentTensor>(context);
}
1744 
1745 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1746  MLIRContext *context, std::optional<Location> location,
1747  ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1748  if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1749  inferredReturnTypes.assign({ShapeType::get(context)});
1750  else {
1751  auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1752  int64_t rank =
1753  shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1754  Type indexTy = IndexType::get(context);
1755  Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1756  inferredReturnTypes.assign({extentTensorTy});
1757  }
1758  return success();
1759 }
1760 
1761 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1762  if (l.size() != 1 || r.size() != 1)
1763  return false;
1764  if (l == r)
1765  return true;
1766 
1767  Type lhs = l.front();
1768  Type rhs = r.front();
1769 
1770  if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1771  !llvm::isa<ShapeType, ShapedType>(rhs))
1772  return false;
1773 
1774  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1775  // Shape type is compatible with all other valid return types.
1776  return true;
1777 
1778  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1779  return true;
1780  return false;
1781 }
1782 
1784  return verifyShapeOrExtentTensorOp(*this);
1785 }
1786 
1787 //===----------------------------------------------------------------------===//
1788 // SizeToIndexOp
1789 //===----------------------------------------------------------------------===//
1790 
1791 OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1792  // Constant values of both types, `shape.size` and `index`, are represented as
1793  // `IntegerAttr`s which makes constant folding simple.
1794  if (Attribute arg = adaptor.getArg())
1795  return arg;
1796  return OpFoldResult();
1797 }
1798 
/// Registers the canonicalization pattern for `shape.size_to_index`,
/// generated from ShapeCanonicalization.td.
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
1803 
1804 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1805  if (inputs.size() != 1 || outputs.size() != 1)
1806  return false;
1807  return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1808  llvm::isa<IndexType>(outputs[0]);
1809 }
1810 
1811 //===----------------------------------------------------------------------===//
1812 // YieldOp
1813 //===----------------------------------------------------------------------===//
1814 
1816  auto *parentOp = (*this)->getParentOp();
1817  auto results = parentOp->getResults();
1818  auto operands = getOperands();
1819 
1820  if (parentOp->getNumResults() != getNumOperands())
1821  return emitOpError() << "number of operands does not match number of "
1822  "results of its parent";
1823  for (auto e : llvm::zip(results, operands))
1824  if (std::get<0>(e).getType() != std::get<1>(e).getType())
1825  return emitOpError() << "types mismatch between yield op and its parent";
1826 
1827  return success();
1828 }
1829 
1830 //===----------------------------------------------------------------------===//
1831 // SplitAtOp
1832 //===----------------------------------------------------------------------===//
1833 
/// Constant-folds `shape.split_at`: given a constant shape and a constant
/// split index, produces two index-tensor attributes for the head and tail
/// parts. Fails (no fold) when either operand is non-constant or the split
/// point is out of range.
LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  // Both the shape and the index must be constant to fold.
  if (!adaptor.getOperand() || !adaptor.getIndex())
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
  auto shape = llvm::ArrayRef(shapeVec);
  auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (-rank > splitPoint || splitPoint > rank)
    return failure();
  // A negative split point counts from the back of the shape.
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(adaptor.getOperand().getContext());
  // Result 0 is the prefix [0, splitPoint); result 1 the suffix [splitPoint, rank).
  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
  return success();
}
1854 
1855 //===----------------------------------------------------------------------===//
1856 // ToExtentTensorOp
1857 //===----------------------------------------------------------------------===//
1858 
1859 OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1860  if (!adaptor.getInput())
1861  return OpFoldResult();
1862  Builder builder(getContext());
1863  auto shape = llvm::to_vector<6>(
1864  llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1865  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1866  builder.getIndexType());
1867  return DenseIntElementsAttr::get(type, shape);
1868 }
1869 
1870 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1871  if (inputs.size() != 1 || outputs.size() != 1)
1872  return false;
1873  if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1874  if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1875  inputTensor.getRank() != 1)
1876  return false;
1877  } else if (!llvm::isa<ShapeType>(inputs[0])) {
1878  return false;
1879  }
1880 
1881  TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1882  return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1883 }
1884 
1885 //===----------------------------------------------------------------------===//
1886 // ReduceOp
1887 //===----------------------------------------------------------------------===//
1888 
/// Builds a `shape.reduce` over `shape` with initial accumulator values
/// `initVals`. Creates the body region with one block whose implicit
/// arguments are (index, extent, init-values...); the caller is expected to
/// fill in the body afterwards. One result is added per initial value.
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  result.addOperands(shape);
  result.addOperands(initVals);

  // Create the body region with a single, initially-empty block.
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  // Argument 0: the iteration index.
  bodyBlock.addArgument(builder.getIndexType(), result.location);

  // Argument 1: the extent at the index — the tensor element type when
  // reducing over an extent tensor, `!shape.size` otherwise.
  Type elementType;
  if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock.addArgument(elementType, shape.getLoc());

  // Remaining arguments mirror the accumulators; each also yields a result.
  for (Value initVal : initVals) {
    bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
    result.addTypes(initVal.getType());
  }
}
1911 
1913  // Verify block arg types.
1914  Block &block = getRegion().front();
1915 
1916  // The block takes index, extent, and aggregated values as arguments.
1917  auto blockArgsCount = getInitVals().size() + 2;
1918  if (block.getNumArguments() != blockArgsCount)
1919  return emitOpError() << "ReduceOp body is expected to have "
1920  << blockArgsCount << " arguments";
1921 
1922  // The first block argument is the index and must always be of type `index`.
1923  if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
1924  return emitOpError(
1925  "argument 0 of ReduceOp body is expected to be of IndexType");
1926 
1927  // The second block argument is the extent and must be of type `size` or
1928  // `index`, depending on whether the reduce operation is applied to a shape or
1929  // to an extent tensor.
1930  Type extentTy = block.getArgument(1).getType();
1931  if (llvm::isa<ShapeType>(getShape().getType())) {
1932  if (!llvm::isa<SizeType>(extentTy))
1933  return emitOpError("argument 1 of ReduceOp body is expected to be of "
1934  "SizeType if the ReduceOp operates on a ShapeType");
1935  } else {
1936  if (!llvm::isa<IndexType>(extentTy))
1937  return emitOpError(
1938  "argument 1 of ReduceOp body is expected to be of IndexType if the "
1939  "ReduceOp operates on an extent tensor");
1940  }
1941 
1942  for (const auto &type : llvm::enumerate(getInitVals()))
1943  if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1944  return emitOpError() << "type mismatch between argument "
1945  << type.index() + 2
1946  << " of ReduceOp body and initial value "
1947  << type.index();
1948  return success();
1949 }
1950 
/// Parses a `shape.reduce` op of the form:
///   `(` shape `,` init-vals `) :` shape-type (`->` result-types)?
///   region attr-dict?
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands: the first is the shape (typed by the colon type), the
  // rest are the initial accumulator values typed like the results.
  auto initVals = llvm::ArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1980 
/// Prints a `shape.reduce` op in the same form that `parse` accepts:
/// `(` shape `,` init-vals `) :` shape-type arrow-types? region attr-dict?
void ReduceOp::print(OpAsmPrinter &p) {
  p << '(' << getShape() << ", " << getInitVals()
    << ") : " << getShape().getType();
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}
1989 
1990 #define GET_OP_CLASSES
1991 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1992 
1993 #define GET_TYPEDEF_CLASSES
1994 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
static bool isErrorPropagationPossible(TypeRange operandTypes)
Definition: Shape.cpp:66
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
Definition: Shape.cpp:962
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
Definition: Shape.cpp:84
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
Definition: Shape.cpp:97
static LogicalResult verifySizeOrIndexOp(Operation *op)
Definition: Shape.cpp:72
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1333
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
unsigned getNumArguments()
Definition: Block.h:121
Operation & back()
Definition: Block.h:145
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:147
Operation & front()
Definition: Block.h:146
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:93
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:209
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:110
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:43
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:45
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:198
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:212
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:421
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:427
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:400
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:831
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:728
operand_iterator operand_begin()
Definition: Operation.h:369
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:486
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
operand_iterator operand_end()
Definition: Operation.h:370
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:59
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:90
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
ValueTypeRange< ValueRange > type_range
Definition: ValueRange.h:401
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition: Traits.cpp:25
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:60
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
bool isExtentTensorType(Type)
Definition: Shape.cpp:44
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
Definition: Shape.cpp:49
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition: Shape.cpp:40
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:491
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.