MLIR  19.0.0git
Shape.cpp
Go to the documentation of this file.
1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
12 
17 #include "mlir/Dialect/Traits.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
27 #include "llvm/ADT/SetOperations.h"
28 #include "llvm/ADT/SmallString.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/raw_ostream.h"
31 
32 using namespace mlir;
33 using namespace mlir::shape;
34 
35 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
36 
37 namespace {
38 #include "ShapeCanonicalization.inc"
39 } // namespace
40 
41 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
42  return RankedTensorType::get({rank}, IndexType::get(ctx));
43 }
44 
46  auto ranked = llvm::dyn_cast<RankedTensorType>(type);
47  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
48 }
49 
50 LogicalResult shape::getShapeVec(Value input,
51  SmallVectorImpl<int64_t> &shapeValues) {
52  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
53  auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
54  if (!type.hasRank())
55  return failure();
56  llvm::append_range(shapeValues, type.getShape());
57  return success();
58  }
60  if (matchPattern(input, m_Constant(&attr))) {
61  llvm::append_range(shapeValues, attr.getValues<int64_t>());
62  return success();
63  }
64  return failure();
65 }
66 
67 static bool isErrorPropagationPossible(TypeRange operandTypes) {
68  return llvm::any_of(operandTypes,
69  llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
70 }
71 
72 static LogicalResult verifySizeOrIndexOp(Operation *op) {
73  assert(op != nullptr && op->getNumResults() == 1);
74  Type resultTy = op->getResultTypes().front();
76  if (!llvm::isa<SizeType>(resultTy))
77  return op->emitOpError()
78  << "if at least one of the operands can hold error values then "
79  "the result must be of type `size` to propagate them";
80  }
81  return success();
82 }
83 
84 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
85  assert(op != nullptr && op->getNumResults() == 1);
86  Type resultTy = op->getResultTypes().front();
88  if (!llvm::isa<ShapeType>(resultTy))
89  return op->emitOpError()
90  << "if at least one of the operands can hold error values then "
91  "the result must be of type `shape` to propagate them";
92  }
93  return success();
94 }
95 
96 template <typename... Ty>
97 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
98  return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
99 }
100 
101 template <typename... Ty, typename... ranges>
102 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
103  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // InlinerInterface
108 //===----------------------------------------------------------------------===//
109 
110 namespace {
111 /// This class defines the interface for inlining shape dialect ops.
112 struct ShapeInlinerInterface : public DialectInlinerInterface {
114 
115  // Returns true if the given region 'src' can be inlined into the region
116  // 'dest' that is attached to an operation registered to the current dialect.
117  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
118  IRMapping &) const final {
119  return true;
120  }
121 
122  // Returns true if the given operation 'op', that is registered to this
123  // dialect, can be inlined into the region 'dest' that is attached to an
124  // operation registered to the current dialect.
125  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
126  IRMapping &) const final {
127  return true;
128  }
129 };
130 } // namespace
131 
// Registers all ODS-generated ops and types of the shape dialect, plus the
// inliner interface and promised bufferization interfaces.
void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
  // Defer registration of the bufferization interface until the bufferization
  // dialect actually requests it.
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
                            AssumingYieldOp>();
}
149 
151  Attribute value, Type type,
152  Location loc) {
153  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
154  return builder.create<ub::PoisonOp>(loc, type, poison);
155 
156  if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
157  return builder.create<ConstShapeOp>(
158  loc, type, llvm::cast<DenseIntElementsAttr>(value));
159  if (llvm::isa<SizeType>(type))
160  return builder.create<ConstSizeOp>(loc, type,
161  llvm::cast<IntegerAttr>(value));
162  if (llvm::isa<WitnessType>(type))
163  return builder.create<ConstWitnessOp>(loc, type,
164  llvm::cast<BoolAttr>(value));
165 
166  return arith::ConstantOp::materialize(builder, value, type, loc);
167 }
168 
169 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
170  NamedAttribute attribute) {
171  // Verify shape.lib attribute.
172  if (attribute.getName() == "shape.lib") {
173  if (!op->hasTrait<OpTrait::SymbolTable>())
174  return op->emitError(
175  "shape.lib attribute may only be on op implementing SymbolTable");
176 
177  if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
178  auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
179  if (!symbol)
180  return op->emitError("shape function library ")
181  << symbolRef << " not found";
182  return isa<shape::FunctionLibraryOp>(symbol)
183  ? success()
184  : op->emitError()
185  << symbolRef << " required to be shape function library";
186  }
187 
188  if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
189  // Verify all entries are function libraries and mappings in libraries
190  // refer to unique ops.
192  for (auto it : arr) {
193  if (!llvm::isa<SymbolRefAttr>(it))
194  return op->emitError(
195  "only SymbolRefAttr allowed in shape.lib attribute array");
196 
197  auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
198  SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
199  if (!shapeFnLib)
200  return op->emitError()
201  << it << " does not refer to FunctionLibraryOp";
202  for (auto mapping : shapeFnLib.getMapping()) {
203  if (!key.insert(mapping.getName()).second) {
204  return op->emitError("only one op to shape mapping allowed, found "
205  "multiple for `")
206  << mapping.getName() << "`";
207  }
208  }
209  }
210  return success();
211  }
212 
213  return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
214  "allowed as shape.lib attribute");
215  }
216  return success();
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // AnyOp
221 //===----------------------------------------------------------------------===//
222 
223 // TODO: Canonicalization should be implemented for shapes that can be
224 // determined through mixtures of the known dimensions of the inputs.
225 OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
226  // Only the last operand is checked because AnyOp is commutative.
227  if (adaptor.getInputs().back())
228  return adaptor.getInputs().back();
229 
230  return nullptr;
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // AssumingOp
235 //===----------------------------------------------------------------------===//
236 
237 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
238  result.regions.reserve(1);
239  Region *doRegion = result.addRegion();
240 
241  auto &builder = parser.getBuilder();
243  if (parser.parseOperand(cond) ||
244  parser.resolveOperand(cond, builder.getType<WitnessType>(),
245  result.operands))
246  return failure();
247 
248  // Parse optional results type list.
249  if (parser.parseOptionalArrowTypeList(result.types))
250  return failure();
251 
252  // Parse the region and add a terminator if elided.
253  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
254  return failure();
255  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
256 
257  // Parse the optional attribute list.
258  if (parser.parseOptionalAttrDict(result.attributes))
259  return failure();
260  return success();
261 }
262 
264  bool yieldsResults = !getResults().empty();
265 
266  p << " " << getWitness();
267  if (yieldsResults)
268  p << " -> (" << getResultTypes() << ")";
269  p << ' ';
270  p.printRegion(getDoRegion(),
271  /*printEntryBlockArgs=*/false,
272  /*printBlockTerminators=*/yieldsResults);
273  p.printOptionalAttrDict((*this)->getAttrs());
274 }
275 
276 namespace {
277 // Removes AssumingOp with a passing witness and inlines the region.
278 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
280 
281  LogicalResult matchAndRewrite(AssumingOp op,
282  PatternRewriter &rewriter) const override {
283  auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
284  if (!witness || !witness.getPassingAttr())
285  return failure();
286 
287  AssumingOp::inlineRegionIntoParent(op, rewriter);
288  return success();
289  }
290 };
291 
292 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
294 
295  LogicalResult matchAndRewrite(AssumingOp op,
296  PatternRewriter &rewriter) const override {
297  Block *body = op.getBody();
298  auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
299 
300  // Find used values.
301  SmallVector<Value, 4> newYieldOperands;
302  for (auto [opResult, yieldOperand] :
303  llvm::zip(op.getResults(), yieldOp.getOperands())) {
304  if (!opResult.getUses().empty()) {
305  newYieldOperands.push_back(yieldOperand);
306  }
307  }
308 
309  // Rewrite only if redundant results exist.
310  if (newYieldOperands.size() == yieldOp->getNumOperands())
311  return failure();
312 
313  // Replace yield op in the old assuming op's body and move the entire region
314  // to the new assuming op.
315  rewriter.setInsertionPointToEnd(body);
316  auto newYieldOp =
317  rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
318  rewriter.setInsertionPoint(op);
319  auto newOp = rewriter.create<AssumingOp>(
320  op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
321  newOp.getDoRegion().takeBody(op.getDoRegion());
322 
323  // Use the new results to replace the previously used ones.
324  SmallVector<Value, 4> replacementValues;
325  auto src = newOp.getResults().begin();
326  for (auto it : op.getResults()) {
327  if (it.getUses().empty())
328  replacementValues.push_back(nullptr);
329  else
330  replacementValues.push_back(*src++);
331  }
332  rewriter.replaceOp(op, replacementValues);
333  return success();
334  }
335 };
336 } // namespace
337 
// Registers the `shape.assuming` canonicalizations: dropping unused results
// and inlining the body when the witness is a constant-true.
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}
342 
343 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
344 void AssumingOp::getSuccessorRegions(
346  // AssumingOp has unconditional control flow into the region and back to the
347  // parent, so return the correct RegionSuccessor purely based on the index
348  // being None or 0.
349  if (!point.isParent()) {
350  regions.push_back(RegionSuccessor(getResults()));
351  return;
352  }
353 
354  regions.push_back(RegionSuccessor(&getDoRegion()));
355 }
356 
// Replaces `op` with the contents of its body: the insertion block is split at
// the op, the body region is spliced in between, the op's results are replaced
// by the yielded values, and the three blocks are merged back into one.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
376 
377 void AssumingOp::build(
378  OpBuilder &builder, OperationState &result, Value witness,
380  OpBuilder::InsertionGuard g(builder);
381 
382  result.addOperands(witness);
383  Region *bodyRegion = result.addRegion();
384  builder.createBlock(bodyRegion);
385 
386  // Build body.
387  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
388  builder.create<AssumingYieldOp>(result.location, yieldValues);
389 
390  SmallVector<Type, 2> assumingTypes;
391  for (Value v : yieldValues)
392  assumingTypes.push_back(v.getType());
393  result.addTypes(assumingTypes);
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // AddOp
398 //===----------------------------------------------------------------------===//
399 
400 LogicalResult mlir::shape::AddOp::inferReturnTypes(
401  MLIRContext *context, std::optional<Location> location,
402  AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
403  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
404  llvm::isa<SizeType>(adaptor.getRhs().getType()))
405  inferredReturnTypes.assign({SizeType::get(context)});
406  else
407  inferredReturnTypes.assign({IndexType::get(context)});
408  return success();
409 }
410 
411 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
412  // SizeType is compatible with IndexType.
413  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
414 }
415 
416 OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
417  // add(x, 0) -> x
418  if (matchPattern(getRhs(), m_Zero()))
419  return getLhs();
420 
421  return constFoldBinaryOp<IntegerAttr>(
422  adaptor.getOperands(),
423  [](APInt a, const APInt &b) { return std::move(a) + b; });
424 }
425 
426 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
427 
428 //===----------------------------------------------------------------------===//
429 // AssumingAllOp
430 //===----------------------------------------------------------------------===//
431 
432 namespace {
433 
434 // Merge multiple `shape.assuming_all` operations together.
435 //
436 // %0 = shape.assuming_all %w0, %w1
437 // %1 = shape.assuming_all %w2, %0
438 //
439 // to:
440 //
441 // %0 = shape.assuming_all %w0, %w2, %w2
442 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
444 
445  LogicalResult matchAndRewrite(AssumingAllOp op,
446  PatternRewriter &rewriter) const override {
447  SmallVector<Value> operands;
448 
449  for (Value operand : op.getInputs()) {
450  if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
451  operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
452  else
453  operands.push_back(operand);
454  }
455 
456  // We didn't find any other `assuming_all` ops to merge with.
457  if (operands.size() == op.getNumOperands())
458  return failure();
459 
460  // Replace with a new `assuming_all` operation with merged constraints.
461  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
462  return success();
463  }
464 };
465 
466 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
467 // are subsumed by others.
468 //
469 // %0 = shape.cstr_broadcastable %shape0, %shape1
470 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
471 //
472 // %2 = shape.cstr_broadcastable %shape3, %shape4
473 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
474 //
475 // %4 = shape.assuming_all %0, %1, %2, %3
476 //
477 // to:
478 //
479 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
480 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
481 // %2 = shape.assuming_all %0, %1
482 //
483 // In this example if shapes [0, 1, 2] are broadcastable, then it means that
484 // shapes [0, 1] are broadcastable too, and can be removed from the list of
485 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
486 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
487 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
489 
490  LogicalResult matchAndRewrite(AssumingAllOp op,
491  PatternRewriter &rewriter) const override {
492  // Collect all `CstrBroadcastableOp` operands first.
494  for (Value operand : op.getInputs()) {
495  // TODO: Apply this optimization if some of the witnesses are not
496  // produced by the `cstr_broadcastable`.
497  auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
498  if (!broadcastable)
499  return failure();
500 
501  operands.insert(broadcastable);
502  }
503 
504  // Skip trivial `assuming_all` operations.
505  if (operands.size() <= 1)
506  return failure();
507 
508  // Collect shapes checked by `cstr_broadcastable` operands.
510  for (auto cstr : operands) {
511  DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
512  shapes.emplace_back(cstr, std::move(shapesSet));
513  }
514 
515  // Sort by the number of shape operands (larger to smaller).
516  llvm::sort(shapes, [](auto a, auto b) {
517  return a.first.getNumOperands() > b.first.getNumOperands();
518  });
519 
520  // We start from the `cst_broadcastable` operations with largest number of
521  // shape operands, and remove redundant `cst_broadcastable` operations. We
522  // do this until we find a set of `cst_broadcastable` operations with
523  // non-overlapping constraints.
524  SmallVector<CstrBroadcastableOp> markedForErase;
525 
526  for (unsigned i = 0; i < shapes.size(); ++i) {
527  auto isSubset = [&](auto pair) {
528  return llvm::set_is_subset(pair.second, shapes[i].second);
529  };
530 
531  // Keep redundant `cstr_broadcastable` operations to be erased.
532  auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
533  for (auto *it0 = it; it0 < shapes.end(); ++it0)
534  markedForErase.push_back(it0->first);
535  shapes.erase(it, shapes.end());
536  }
537 
538  // We didn't find any operands that could be removed.
539  if (markedForErase.empty())
540  return failure();
541 
542  // Collect non-overlapping `cst_broadcastable` constraints.
543  SmallVector<Value> uniqueConstraints;
544  for (auto &shape : shapes)
545  uniqueConstraints.push_back(shape.first.getResult());
546 
547  // Replace with a new `assuming_all` operation ...
548  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
549 
550  // ... and maybe erase `cstr_broadcastable` ops without uses.
551  for (auto &op : markedForErase)
552  if (op->use_empty())
553  rewriter.eraseOp(op);
554 
555  return success();
556  }
557 };
558 
559 struct AssumingAllToCstrEqCanonicalization
560  : public OpRewritePattern<AssumingAllOp> {
562 
563  LogicalResult matchAndRewrite(AssumingAllOp op,
564  PatternRewriter &rewriter) const override {
565  SmallVector<Value, 8> shapes;
566  for (Value w : op.getInputs()) {
567  auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
568  if (!cstrEqOp)
569  return failure();
570  bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
571  return llvm::is_contained(shapes, s);
572  });
573  if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
574  return failure();
575  shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
576  }
577  rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
578  return success();
579  }
580 };
581 
582 template <typename OpTy>
583 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
585 
586  LogicalResult matchAndRewrite(OpTy op,
587  PatternRewriter &rewriter) const override {
588  // Find unique operands.
589  SetVector<Value> unique(op.operand_begin(), op.operand_end());
590 
591  // Reduce op to equivalent with unique operands.
592  if (unique.size() < op.getNumOperands()) {
593  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
594  unique.takeVector(), op->getAttrs());
595  return success();
596  }
597 
598  return failure();
599  }
600 };
601 } // namespace
602 
// Registers `assuming_all` canonicalizations. `AssumingAllOneOp` is a
// declarative pattern generated into ShapeCanonicalization.inc.
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
610 
// Folds `assuming_all` over statically known witnesses. NOTE: this fold
// mutates the op in place by erasing operands that are known constants, which
// is why it iterates over indices rather than a range.
OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
    Attribute a = adaptor.getInputs()[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!llvm::cast<BoolAttr>(a).getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
631 
632 LogicalResult AssumingAllOp::verify() {
633  // Ensure that AssumingAllOp contains at least one operand
634  if (getNumOperands() == 0)
635  return emitOpError("no operands specified");
636 
637  return success();
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // BroadcastOp
642 //===----------------------------------------------------------------------===//
643 
644 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
645  if (getShapes().size() == 1) {
646  // Otherwise, we need a cast which would be a canonicalization, not folding.
647  if (getShapes().front().getType() != getType())
648  return nullptr;
649  return getShapes().front();
650  }
651 
652  // TODO: Support folding with more than 2 input shapes
653  if (getShapes().size() > 2)
654  return nullptr;
655 
656  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
657  return nullptr;
658  auto lhsShape = llvm::to_vector<6>(
659  llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
660  .getValues<int64_t>());
661  auto rhsShape = llvm::to_vector<6>(
662  llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
663  .getValues<int64_t>());
664  SmallVector<int64_t, 6> resultShape;
665 
666  // If the shapes are not compatible, we can't fold it.
667  // TODO: Fold to an "error".
668  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
669  return nullptr;
670 
671  Builder builder(getContext());
672  return builder.getIndexTensorAttr(resultShape);
673 }
674 
675 LogicalResult BroadcastOp::verify() {
676  return verifyShapeOrExtentTensorOp(*this);
677 }
678 
679 namespace {
680 template <typename OpTy>
681 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
683 
684  LogicalResult matchAndRewrite(OpTy op,
685  PatternRewriter &rewriter) const override {
686  auto isPotentiallyNonEmptyShape = [](Value shape) {
687  if (auto extentTensorTy =
688  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
689  if (extentTensorTy.getDimSize(0) == 0)
690  return false;
691  }
692  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
693  if (constShape.getShape().empty())
694  return false;
695  }
696  return true;
697  };
698  auto newOperands = llvm::to_vector<8>(
699  llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
700 
701  // Reduce op to equivalent without empty shape operands.
702  if (newOperands.size() < op.getNumOperands()) {
703  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
704  op->getAttrs());
705  return success();
706  }
707 
708  return failure();
709  }
710 };
711 
712 struct BroadcastForwardSingleOperandPattern
713  : public OpRewritePattern<BroadcastOp> {
715 
716  LogicalResult matchAndRewrite(BroadcastOp op,
717  PatternRewriter &rewriter) const override {
718  if (op.getNumOperands() != 1)
719  return failure();
720  Value replacement = op.getShapes().front();
721 
722  // Insert cast if needed.
723  if (replacement.getType() != op.getType()) {
724  auto loc = op.getLoc();
725  if (llvm::isa<ShapeType>(op.getType())) {
726  replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
727  } else {
728  assert(!llvm::isa<ShapeType>(op.getType()) &&
729  !llvm::isa<ShapeType>(replacement.getType()) &&
730  "expect extent tensor cast");
731  replacement =
732  rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
733  }
734  }
735 
736  rewriter.replaceOp(op, replacement);
737  return success();
738  }
739 };
740 
741 struct BroadcastFoldConstantOperandsPattern
742  : public OpRewritePattern<BroadcastOp> {
744 
745  LogicalResult matchAndRewrite(BroadcastOp op,
746  PatternRewriter &rewriter) const override {
747  SmallVector<int64_t, 8> foldedConstantShape;
748  SmallVector<Value, 8> newShapeOperands;
749  for (Value shape : op.getShapes()) {
750  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
751  SmallVector<int64_t, 8> newFoldedConstantShape;
753  foldedConstantShape,
754  llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
755  newFoldedConstantShape)) {
756  foldedConstantShape = newFoldedConstantShape;
757  continue;
758  }
759  }
760  newShapeOperands.push_back(shape);
761  }
762 
763  // Need at least two constant operands to fold anything.
764  if (op.getNumOperands() - newShapeOperands.size() < 2)
765  return failure();
766 
767  auto foldedConstantOperandsTy = RankedTensorType::get(
768  {static_cast<int64_t>(foldedConstantShape.size())},
769  rewriter.getIndexType());
770  newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
771  op.getLoc(), foldedConstantOperandsTy,
772  rewriter.getIndexTensorAttr(foldedConstantShape)));
773  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
774  newShapeOperands);
775  return success();
776  }
777 };
778 
779 template <typename OpTy>
780 struct CanonicalizeCastExtentTensorOperandsPattern
781  : public OpRewritePattern<OpTy> {
783 
784  LogicalResult matchAndRewrite(OpTy op,
785  PatternRewriter &rewriter) const override {
786  // Canonicalize operands.
787  bool anyChange = false;
788  auto canonicalizeOperand = [&](Value operand) -> Value {
789  if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
790  // Only eliminate the cast if it holds no shape information.
791  bool isInformationLoosingCast =
792  llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
793  if (isInformationLoosingCast) {
794  anyChange = true;
795  return castOp.getSource();
796  }
797  }
798  return operand;
799  };
800  auto newOperands = llvm::to_vector<8>(
801  llvm::map_range(op.getOperands(), canonicalizeOperand));
802 
803  // Rewrite op if any change required.
804  if (!anyChange)
805  return failure();
806  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
807  return success();
808  }
809 };
810 
811 struct BroadcastConcretizeResultTypePattern
812  : public OpRewritePattern<BroadcastOp> {
814 
815  LogicalResult matchAndRewrite(BroadcastOp op,
816  PatternRewriter &rewriter) const override {
817  // Only concretize dynamic extent tensor result types.
818  auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
819  if (!resultTy || !resultTy.isDynamicDim(0))
820  return failure();
821 
822  // Infer resulting shape rank if possible.
823  int64_t maxRank = 0;
824  for (Value shape : op.getShapes()) {
825  if (auto extentTensorTy =
826  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
827  // Cannot infer resulting shape rank if any operand is dynamically
828  // ranked.
829  if (extentTensorTy.isDynamicDim(0))
830  return failure();
831  maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
832  }
833  }
834 
835  auto newOp = rewriter.create<BroadcastOp>(
836  op.getLoc(), getExtentTensorType(getContext(), maxRank),
837  op.getShapes());
838  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
839  return success();
840  }
841 };
842 } // namespace
843 
// Registers `shape.broadcast` canonicalizations: result-type concretization,
// constant folding, single-operand forwarding, cast stripping, and removal of
// duplicate/empty shape operands.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
853 
854 //===----------------------------------------------------------------------===//
855 // ConcatOp
856 //===----------------------------------------------------------------------===//
857 
858 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
859  if (!adaptor.getLhs() || !adaptor.getRhs())
860  return nullptr;
861  auto lhsShape = llvm::to_vector<6>(
862  llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
863  auto rhsShape = llvm::to_vector<6>(
864  llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
865  SmallVector<int64_t, 6> resultShape;
866  resultShape.append(lhsShape.begin(), lhsShape.end());
867  resultShape.append(rhsShape.begin(), rhsShape.end());
868  Builder builder(getContext());
869  return builder.getIndexTensorAttr(resultShape);
870 }
871 
872 //===----------------------------------------------------------------------===//
873 // ConstShapeOp
874 //===----------------------------------------------------------------------===//
875 
877  p << " ";
878  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
879  p << "[";
880  interleaveComma(getShape().getValues<int64_t>(), p);
881  p << "] : ";
882  p.printType(getType());
883 }
884 
885 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
886  if (parser.parseOptionalAttrDict(result.attributes))
887  return failure();
888  // We piggy-back on ArrayAttr parsing, though we don't internally store the
889  // shape as an ArrayAttr.
890  // TODO: Implement custom parser and maybe make syntax a bit more concise.
891  Attribute extentsRaw;
892  NamedAttrList dummy;
893  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
894  return failure();
895  auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
896  if (!extentsArray)
897  return failure();
899  for (Attribute extent : extentsArray) {
900  IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
901  if (!attr)
902  return failure();
903  ints.push_back(attr.getInt());
904  }
905  Builder &builder = parser.getBuilder();
906  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
907  Type resultTy;
908  if (parser.parseColonType(resultTy))
909  return failure();
910  result.types.push_back(resultTy);
911  return success();
912 }
913 
914 OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
915 
// `TensorCastConstShape` is a declarative pattern generated into
// ShapeCanonicalization.inc; it folds tensor casts of constant shapes.
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}
920 
921 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
922  MLIRContext *context, std::optional<Location> location,
923  ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
924  Builder b(context);
925  const Properties prop = adaptor.getProperties();
926  inferredReturnTypes.assign({RankedTensorType::get(
927  {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
928  return success();
929 }
930 
931 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
932  TypeRange r) {
933  if (l.size() != 1 || r.size() != 1)
934  return false;
935 
936  Type lhs = l.front();
937  Type rhs = r.front();
938 
939  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
940  // Shape type is compatible with all other valid return types.
941  return true;
942  return lhs == rhs;
943 }
944 
945 //===----------------------------------------------------------------------===//
946 // CstrBroadcastableOp
947 //===----------------------------------------------------------------------===//
948 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  // The patterns drop extent-tensor casts, merge equal operands, and remove
  // duplicate or statically-empty shape operands.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
959 
960 // Return true if there is exactly one attribute not representing a scalar
961 // broadcast.
963  bool nonScalarSeen = false;
964  for (Attribute a : attributes) {
965  if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
966  if (nonScalarSeen)
967  return false;
968  nonScalarSeen = true;
969  }
970  }
971  return true;
972 }
973 
974 OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
975  // No broadcasting is needed if all operands but one are scalar.
976  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
977  return BoolAttr::get(getContext(), true);
978 
979  if ([&] {
981  for (const auto &operand : adaptor.getShapes()) {
982  if (!operand)
983  return false;
984  extents.push_back(llvm::to_vector<6>(
985  llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
986  }
988  }())
989  return BoolAttr::get(getContext(), true);
990 
991  // Lastly, see if folding can be completed based on what constraints are known
992  // on the input shapes.
993  if ([&] {
995  for (auto shapeValue : getShapes()) {
996  extents.emplace_back();
997  if (failed(getShapeVec(shapeValue, extents.back())))
998  return false;
999  }
1001  }())
1002  return BoolAttr::get(getContext(), true);
1003 
1004  // Because a failing witness result here represents an eventual assertion
1005  // failure, we do not replace it with a constant witness.
1006  return nullptr;
1007 }
1008 
1009 LogicalResult CstrBroadcastableOp::verify() {
1010  // Ensure that CstrBroadcastableOp contains at least two operands
1011  if (getNumOperands() < 2)
1012  return emitOpError("required at least 2 input shapes");
1013  return success();
1014 }
1015 
1016 //===----------------------------------------------------------------------===//
1017 // CstrEqOp
1018 //===----------------------------------------------------------------------===//
1019 
void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If all inputs are the same SSA value, replace with a passing witness.
  patterns.add<CstrEqEqOps>(context);
}
1025 
1026 OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
1027  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
1028  return a && a == adaptor.getShapes().front();
1029  }))
1030  return BoolAttr::get(getContext(), true);
1031 
1032  // Because a failing witness result here represents an eventual assertion
1033  // failure, we do not try to replace it with a constant witness. Similarly, we
1034  // cannot if there are any non-const inputs.
1035  return nullptr;
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // ConstSizeOp
1040 //===----------------------------------------------------------------------===//
1041 
// Convenience builder: wraps the raw integer in an index attribute.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
1046 
// A const_size op always folds to its constant "value" attribute.
OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
1048 
1049 void ConstSizeOp::getAsmResultNames(
1050  llvm::function_ref<void(Value, StringRef)> setNameFn) {
1051  SmallString<4> buffer;
1052  llvm::raw_svector_ostream os(buffer);
1053  os << "c" << getValue();
1054  setNameFn(getResult(), os.str());
1055 }
1056 
1057 //===----------------------------------------------------------------------===//
1058 // ConstWitnessOp
1059 //===----------------------------------------------------------------------===//
1060 
// A const_witness op always folds to its constant "passing" attribute.
OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
1062 
1063 //===----------------------------------------------------------------------===//
1064 // CstrRequireOp
1065 //===----------------------------------------------------------------------===//
1066 
OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
  // Folds iff the predicate operand is constant; the witness equals it.
  return adaptor.getPred();
}
1070 
1071 //===----------------------------------------------------------------------===//
1072 // DimOp
1073 //===----------------------------------------------------------------------===//
1074 
1075 std::optional<int64_t> DimOp::getConstantIndex() {
1076  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1077  return constSizeOp.getValue().getLimitedValue();
1078  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1079  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1080  return std::nullopt;
1081 }
1082 
1083 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1084  Type valType = getValue().getType();
1085  auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1086  if (!valShapedType || !valShapedType.hasRank())
1087  return nullptr;
1088  std::optional<int64_t> index = getConstantIndex();
1089  if (!index.has_value())
1090  return nullptr;
1091  if (index.value() < 0 || index.value() >= valShapedType.getRank())
1092  return nullptr;
1093  auto extent = valShapedType.getDimSize(*index);
1094  if (ShapedType::isDynamic(extent))
1095  return nullptr;
1096  return IntegerAttr::get(IndexType::get(getContext()), extent);
1097 }
1098 
LogicalResult mlir::shape::DimOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  // The result type mirrors the index operand's type (!shape.size or index).
  inferredReturnTypes.assign({adaptor.getIndex().getType()});
  return success();
}
1105 
bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1109 
1110 //===----------------------------------------------------------------------===//
1111 // DivOp
1112 //===----------------------------------------------------------------------===//
1113 
1114 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
1115  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1116  if (!lhs)
1117  return nullptr;
1118  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1119  if (!rhs)
1120  return nullptr;
1121 
1122  // Division in APInt does not follow floor(lhs, rhs) when the result is
1123  // negative. Rather, APInt rounds toward zero.
1124  APInt quotient, remainder;
1125  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1126  if (quotient.isNegative() && !remainder.isZero()) {
1127  quotient -= 1;
1128  }
1129 
1130  Type indexTy = IndexType::get(getContext());
1131  return IntegerAttr::get(indexTy, quotient);
1132 }
1133 
LogicalResult mlir::shape::DivOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  // If either operand is the error-carrying !shape.size type, so is the
  // result; otherwise the result is a plain index.
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1144 
bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1149 
// Delegates to the shared size/index result-type verifier.
LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1151 
1152 //===----------------------------------------------------------------------===//
1153 // ShapeEqOp
1154 //===----------------------------------------------------------------------===//
1155 
1156 OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
1157  bool allSame = true;
1158  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1159  return {};
1160  for (Attribute operand : adaptor.getShapes().drop_front()) {
1161  if (!operand)
1162  return {};
1163  allSame = allSame && operand == adaptor.getShapes().front();
1164  }
1165  return BoolAttr::get(getContext(), allSame);
1166 }
1167 
1168 //===----------------------------------------------------------------------===//
1169 // IndexToSizeOp
1170 //===----------------------------------------------------------------------===//
1171 
1172 OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1173  // Constant values of both types, `shape.size` and `index`, are represented as
1174  // `IntegerAttr`s which makes constant folding simple.
1175  if (Attribute arg = adaptor.getArg())
1176  return arg;
1177  return {};
1178 }
1179 
void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Removes size_to_index/index_to_size round-trip conversions.
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
1184 
1185 //===----------------------------------------------------------------------===//
1186 // FromExtentsOp
1187 //===----------------------------------------------------------------------===//
1188 
1189 OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1190  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
1191  return nullptr;
1192  SmallVector<int64_t, 6> extents;
1193  for (auto attr : adaptor.getExtents())
1194  extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1195  Builder builder(getContext());
1196  return builder.getIndexTensorAttr(extents);
1197 }
1198 
1199 //===----------------------------------------------------------------------===//
1200 // FunctionLibraryOp
1201 //===----------------------------------------------------------------------===//
1202 
1203 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1204  StringRef name) {
1205  result.attributes.push_back(builder.getNamedAttr(
1207 }
1208 
1209 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1210  auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1211  getMapping().get(op->getName().getIdentifier()));
1212  if (!attr)
1213  return nullptr;
1214  return lookupSymbol<FuncOp>(attr);
1215 }
1216 
1217 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1218  OperationState &result) {
1219  // Parse the op name.
1220  StringAttr nameAttr;
1221  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1222  result.attributes))
1223  return failure();
1224 
1225  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1226  return failure();
1227 
1228  auto *bodyRegion = result.addRegion();
1229  if (parser.parseRegion(*bodyRegion))
1230  return failure();
1231 
1232  if (parser.parseKeyword("mapping"))
1233  return failure();
1234 
1235  DictionaryAttr mappingAttr;
1236  if (parser.parseAttribute(mappingAttr,
1237  parser.getBuilder().getType<NoneType>(), "mapping",
1238  result.attributes))
1239  return failure();
1240  return success();
1241 }
1242 
1244  p << ' ';
1245  p.printSymbolName(getName());
1247  (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1248  p << ' ';
1249  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1250  /*printBlockTerminators=*/false);
1251  p << " mapping ";
1252  p.printAttributeWithoutType(getMappingAttr());
1253 }
1254 
1255 //===----------------------------------------------------------------------===//
1256 // FuncOp
1257 //===----------------------------------------------------------------------===//
1258 
// Creates a detached (not yet inserted) shape.func operation.
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder builder(location->getContext());
  OperationState state(location, getOperationName());
  FuncOp::build(builder, state, name, type, attrs);
  return cast<FuncOp>(Operation::create(state));
}
1266 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1268  SmallVector<NamedAttribute, 8> attrRef(attrs);
1269  return create(location, name, type, llvm::ArrayRef(attrRef));
1270 }
1271 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1273  ArrayRef<DictionaryAttr> argAttrs) {
1274  FuncOp func = create(location, name, type, attrs);
1275  func.setAllArgAttrs(argAttrs);
1276  return func;
1277 }
1278 
1279 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
1280  FunctionType type, ArrayRef<NamedAttribute> attrs,
1281  ArrayRef<DictionaryAttr> argAttrs) {
1282  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
1283  builder.getStringAttr(name));
1284  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
1285  TypeAttr::get(type));
1286  state.attributes.append(attrs.begin(), attrs.end());
1287  state.addRegion();
1288 
1289  if (argAttrs.empty())
1290  return;
1291  assert(type.getNumInputs() == argAttrs.size());
1293  builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
1294  getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
1295 }
1296 
1297 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1298  auto buildFuncType =
1299  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
1301  std::string &) { return builder.getFunctionType(argTypes, results); };
1302 
1304  parser, result, /*allowVariadic=*/false,
1305  getFunctionTypeAttrName(result.name), buildFuncType,
1306  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1307 }
1308 
1309 void FuncOp::print(OpAsmPrinter &p) {
1311  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
1312  getArgAttrsAttrName(), getResAttrsAttrName());
1313 }
1314 
1315 //===----------------------------------------------------------------------===//
1316 // GetExtentOp
1317 //===----------------------------------------------------------------------===//
1318 
1319 std::optional<int64_t> GetExtentOp::getConstantDim() {
1320  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1321  return constSizeOp.getValue().getLimitedValue();
1322  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1323  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1324  return std::nullopt;
1325 }
1326 
1327 OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
1328  auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1329  if (!elements)
1330  return nullptr;
1331  std::optional<int64_t> dim = getConstantDim();
1332  if (!dim.has_value())
1333  return nullptr;
1334  if (dim.value() >= elements.getNumElements())
1335  return nullptr;
1336  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1337 }
1338 
1339 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1340  int64_t dim) {
1341  auto loc = result.location;
1342  auto dimAttr = builder.getIndexAttr(dim);
1343  if (llvm::isa<ShapeType>(shape.getType())) {
1344  Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1345  build(builder, result, builder.getType<SizeType>(), shape, dim);
1346  } else {
1347  Value dim =
1348  builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1349  build(builder, result, builder.getIndexType(), shape, dim);
1350  }
1351 }
1352 
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  // The inferred result is always a plain index.
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1359 
bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1365 
// Delegates to the shared size/index result-type verifier.
LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1367 
1368 //===----------------------------------------------------------------------===//
1369 // IsBroadcastableOp
1370 //===----------------------------------------------------------------------===//
1371 
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  // Broadcastability is unaffected by repeated shape operands; drop them.
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
1376 
1377 OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1378  // Can always broadcast fewer than two shapes.
1379  if (adaptor.getShapes().size() < 2) {
1380  return BoolAttr::get(getContext(), true);
1381  }
1382 
1383  return nullptr;
1384 }
1385 
1386 //===----------------------------------------------------------------------===//
1387 // MeetOp
1388 //===----------------------------------------------------------------------===//
1389 
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getOperands().empty())
    return failure();

  // A "shape-like" type is either !shape.shape or a rank-1 index tensor.
  auto isShapeType = [](Type arg) {
    if (llvm::isa<ShapeType>(arg))
      return true;
    return isExtentTensorType(arg);
  };

  // Fold the operand types pairwise into an accumulator `acc`, diagnosing
  // incompatible combinations along the way.
  ValueRange::type_range types = adaptor.getOperands().getTypes();
  Type acc = types.front();
  for (auto t : drop_begin(types)) {
    Type l = acc, r = t;
    // Normalize so that a shape/size type, if present, sits on the left.
    if (!llvm::isa<ShapeType, SizeType>(l))
      std::swap(l, r);

    // Handle sizes, propagate error type if present.
    if (llvm::isa<SizeType>(l)) {
      if (llvm::isa<SizeType, IndexType>(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<IndexType>(l)) {
      if (llvm::isa<IndexType>(r))
        acc = r;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<ShapeType>(l)) {
      // Handle shapes, propagate error type if present.
      if (isShapeType(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (isExtentTensorType(l)) {
      // Two extent tensors: a static rank wins over a dynamic one, and two
      // static ranks must agree.
      auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
      auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
      if (ShapedType::isDynamic(rank1))
        acc = l;
      else if (ShapedType::isDynamic(rank2))
        acc = r;
      else if (rank1 != rank2)
        return emitOptionalError(location, "unequal shape cardinality");
      else
        acc = l;
    }
  }
  inferredReturnTypes.assign({acc});
  return success();
}
1442 
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  // Normalize so that a shape/size type, if present, sits on the left.
  if (!llvm::isa<ShapeType, SizeType>(lhs))
    std::swap(lhs, rhs);

  // A size is compatible with sizes and indices; a shape with shapes and
  // tensors.
  if (llvm::isa<SizeType>(lhs))
    return llvm::isa<SizeType, IndexType>(rhs);
  if (llvm::isa<ShapeType>(lhs))
    return llvm::isa<ShapeType, TensorType>(rhs);

  // Otherwise fall back to shaped-type compatibility.
  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
1464 
1465 //===----------------------------------------------------------------------===//
1466 // RankOp
1467 //===----------------------------------------------------------------------===//
1468 
1469 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1470  auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1471  if (!shape)
1472  return {};
1473  int64_t rank = shape.getNumElements();
1474  Builder builder(getContext());
1475  return builder.getIndexAttr(rank);
1476 }
1477 
1478 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1479 /// Constant folding fails in cases where only the rank is constant, not the
1480 /// shape itself.
1481 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1482 ///
1483 /// Example:
1484 ///
1485 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1486 /// %rank = shape.rank %shape
1487 ///
1488 /// becomes
1489 ///
1490 /// %rank = shape.const_size 3
1491 
1492 namespace {
1493 struct RankShapeOfCanonicalizationPattern
1494  : public OpRewritePattern<shape::RankOp> {
1496 
1497  LogicalResult matchAndRewrite(shape::RankOp op,
1498  PatternRewriter &rewriter) const override {
1499  auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1500  if (!shapeOfOp)
1501  return failure();
1502  auto rankedTensorType =
1503  llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1504  if (!rankedTensorType)
1505  return failure();
1506  int64_t rank = rankedTensorType.getRank();
1507  if (llvm::isa<IndexType>(op.getType())) {
1508  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1509  rank);
1510  } else if (llvm::isa<shape::SizeType>(op.getType())) {
1511  rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1512  } else {
1513  return failure();
1514  }
1515  return success();
1516  }
1517 };
1518 } // namespace
1519 
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Evaluates rank(shape_of(%ranked_tensor)) at compile time.
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
1524 
LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  // A !shape.shape operand yields a !shape.size; an extent tensor yields a
  // plain index.
  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1534 
bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1539 
// Delegates to the shared size/index result-type verifier.
LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1541 
1542 //===----------------------------------------------------------------------===//
1543 // NumElementsOp
1544 //===----------------------------------------------------------------------===//
1545 
1546 OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1547 
1548  // Fold only when argument constant.
1549  Attribute shape = adaptor.getShape();
1550  if (!shape)
1551  return {};
1552 
1553  APInt product(64, 1);
1554  for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
1555  product *= value;
1556  Builder builder(getContext());
1557  return builder.getIndexAttr(product.getLimitedValue());
1558 }
1559 
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    NumElementsOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // A !shape.shape operand yields a !shape.size; an extent tensor yields a
  // plain index.
  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1570 
bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1576 
LogicalResult shape::NumElementsOp::verify() {
  // Delegates to the shared size/index result-type verifier.
  return verifySizeOrIndexOp(*this);
}
1580 
1581 //===----------------------------------------------------------------------===//
1582 // MaxOp
1583 //===----------------------------------------------------------------------===//
1584 
1585 OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
1586  // If operands are equal, just propagate one.
1587  if (getLhs() == getRhs())
1588  return getLhs();
1589  return nullptr;
1590 }
1591 
1592 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1593  MLIRContext *context, std::optional<Location> location,
1594  MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1595  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1596  inferredReturnTypes.assign({adaptor.getLhs().getType()});
1597  else
1598  inferredReturnTypes.assign({SizeType::get(context)});
1599  return success();
1600 }
1601 
1602 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1603  if (l.size() != 1 || r.size() != 1)
1604  return false;
1605  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1606  return true;
1607  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1608  return true;
1609  return false;
1610 }
1611 
1612 //===----------------------------------------------------------------------===//
1613 // MinOp
1614 //===----------------------------------------------------------------------===//
1615 
1616 OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
1617  // If operands are equal, just propagate one.
1618  if (getLhs() == getRhs())
1619  return getLhs();
1620  return nullptr;
1621 }
1622 
LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  // Matching operand types propagate unchanged; mixed operands produce a
  // !shape.size.
  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
    inferredReturnTypes.assign({adaptor.getLhs().getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}
1632 
1633 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1634  if (l.size() != 1 || r.size() != 1)
1635  return false;
1636  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1637  return true;
1638  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1639  return true;
1640  return false;
1641 }
1642 
1643 //===----------------------------------------------------------------------===//
1644 // MulOp
1645 //===----------------------------------------------------------------------===//
1646 
1647 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1648  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1649  if (!lhs)
1650  return nullptr;
1651  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1652  if (!rhs)
1653  return nullptr;
1654  APInt folded = lhs.getValue() * rhs.getValue();
1655  Type indexTy = IndexType::get(getContext());
1656  return IntegerAttr::get(indexTy, folded);
1657 }
1658 
LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  // If either factor is the error-carrying !shape.size type, so is the
  // result; otherwise the result is a plain index.
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1669 
bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1674 
// Delegates to the shared size/index result-type verifier.
LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1676 
1677 //===----------------------------------------------------------------------===//
1678 // ShapeOfOp
1679 //===----------------------------------------------------------------------===//
1680 
1681 namespace {
1682 /// Replace shape_of(x) where x has a constant shape with a const_shape op.
1683 struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
1685 
1686  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1687  PatternRewriter &rewriter) const override {
1688  auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1689  if (!type || !type.hasStaticShape())
1690  return failure();
1691  Location loc = op.getLoc();
1692  Value constShape =
1693  rewriter
1694  .create<ConstShapeOp>(loc,
1695  rewriter.getIndexTensorAttr(type.getShape()))
1696  .getResult();
1697  if (constShape.getType() != op.getResult().getType())
1698  constShape = rewriter.create<tensor::CastOp>(
1699  loc, op.getResult().getType(), constShape);
1700  rewriter.replaceOp(op, constShape);
1701  return success();
1702  }
1703 };
1704 
1705 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
1707 
1708  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1709  PatternRewriter &rewriter) const override {
1710  if (!llvm::isa<ShapedType>(op.getArg().getType()))
1711  return failure();
1712  if (llvm::isa<ShapedType>(op.getType()))
1713  return failure();
1714 
1715  rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
1716  op.getArg());
1717  return success();
1718  }
1719 };
1720 
1721 // Canonicalize
1722 // ```
1723 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1724 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1725 // ```
1726 // to
1727 // ```
1728 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1729 // ```
1730 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1732 
1733  LogicalResult matchAndRewrite(tensor::CastOp op,
1734  PatternRewriter &rewriter) const override {
1735  auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1736  if (!ty || ty.getRank() != 1)
1737  return failure();
1738 
1739  auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1740  if (!shapeOfOp)
1741  return failure();
1742 
1743  // Argument type must be ranked and must not conflict.
1744  auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1745  if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1746  return failure();
1747 
1748  rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1749  return success();
1750  }
1751 };
1752 } // namespace
1753 
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  // Constant-shape folding, result-type tightening, cast absorption, and
  // extent extraction rewrites for shape_of.
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
               ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
      context);
}
1760 
1761 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1762  MLIRContext *context, std::optional<Location> location,
1763  ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1764  if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1765  inferredReturnTypes.assign({ShapeType::get(context)});
1766  else {
1767  auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1768  int64_t rank =
1769  shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1770  Type indexTy = IndexType::get(context);
1771  Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1772  inferredReturnTypes.assign({extentTensorTy});
1773  }
1774  return success();
1775 }
1776 
1777 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1778  if (l.size() != 1 || r.size() != 1)
1779  return false;
1780  if (l == r)
1781  return true;
1782 
1783  Type lhs = l.front();
1784  Type rhs = r.front();
1785 
1786  if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1787  !llvm::isa<ShapeType, ShapedType>(rhs))
1788  return false;
1789 
1790  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1791  // Shape type is compatible with all other valid return types.
1792  return true;
1793 
1794  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1795  return true;
1796  return false;
1797 }
1798 
1799 LogicalResult shape::ShapeOfOp::verify() {
1800  return verifyShapeOrExtentTensorOp(*this);
1801 }
1802 
1803 //===----------------------------------------------------------------------===//
1804 // SizeToIndexOp
1805 //===----------------------------------------------------------------------===//
1806 
1807 OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1808  // Constant values of both types, `shape.size` and `index`, are represented as
1809  // `IntegerAttr`s which makes constant folding simple.
1810  if (Attribute arg = adaptor.getArg())
1811  return arg;
1812  return OpFoldResult();
1813 }
1814 
1815 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1816  MLIRContext *context) {
1817  patterns.add<IndexToSizeToIndexCanonicalization>(context);
1818 }
1819 
1820 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1821  if (inputs.size() != 1 || outputs.size() != 1)
1822  return false;
1823  return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1824  llvm::isa<IndexType>(outputs[0]);
1825 }
1826 
1827 //===----------------------------------------------------------------------===//
1828 // YieldOp
1829 //===----------------------------------------------------------------------===//
1830 
1831 LogicalResult shape::YieldOp::verify() {
1832  auto *parentOp = (*this)->getParentOp();
1833  auto results = parentOp->getResults();
1834  auto operands = getOperands();
1835 
1836  if (parentOp->getNumResults() != getNumOperands())
1837  return emitOpError() << "number of operands does not match number of "
1838  "results of its parent";
1839  for (auto e : llvm::zip(results, operands))
1840  if (std::get<0>(e).getType() != std::get<1>(e).getType())
1841  return emitOpError() << "types mismatch between yield op and its parent";
1842 
1843  return success();
1844 }
1845 
1846 //===----------------------------------------------------------------------===//
1847 // SplitAtOp
1848 //===----------------------------------------------------------------------===//
1849 
1850 LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1851  SmallVectorImpl<OpFoldResult> &results) {
1852  if (!adaptor.getOperand() || !adaptor.getIndex())
1853  return failure();
1854  auto shapeVec = llvm::to_vector<6>(
1855  llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1856  auto shape = llvm::ArrayRef(shapeVec);
1857  auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1858  // Verify that the split point is in the correct range.
1859  // TODO: Constant fold to an "error".
1860  int64_t rank = shape.size();
1861  if (-rank > splitPoint || splitPoint > rank)
1862  return failure();
1863  if (splitPoint < 0)
1864  splitPoint += shape.size();
1865  Builder builder(adaptor.getOperand().getContext());
1866  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1867  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1868  return success();
1869 }
1870 
1871 //===----------------------------------------------------------------------===//
1872 // ToExtentTensorOp
1873 //===----------------------------------------------------------------------===//
1874 
1875 OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1876  if (!adaptor.getInput())
1877  return OpFoldResult();
1878  Builder builder(getContext());
1879  auto shape = llvm::to_vector<6>(
1880  llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1881  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1882  builder.getIndexType());
1883  return DenseIntElementsAttr::get(type, shape);
1884 }
1885 
1886 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1887  if (inputs.size() != 1 || outputs.size() != 1)
1888  return false;
1889  if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1890  if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1891  inputTensor.getRank() != 1)
1892  return false;
1893  } else if (!llvm::isa<ShapeType>(inputs[0])) {
1894  return false;
1895  }
1896 
1897  TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1898  return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1899 }
1900 
1901 //===----------------------------------------------------------------------===//
1902 // ReduceOp
1903 //===----------------------------------------------------------------------===//
1904 
1905 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1906  ValueRange initVals) {
1907  OpBuilder::InsertionGuard g(builder);
1908  result.addOperands(shape);
1909  result.addOperands(initVals);
1910 
1911  Region *bodyRegion = result.addRegion();
1912  Block *bodyBlock = builder.createBlock(
1913  bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);
1914 
1915  Type elementType;
1916  if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
1917  elementType = tensorType.getElementType();
1918  else
1919  elementType = SizeType::get(builder.getContext());
1920  bodyBlock->addArgument(elementType, shape.getLoc());
1921 
1922  for (Value initVal : initVals) {
1923  bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
1924  result.addTypes(initVal.getType());
1925  }
1926 }
1927 
1928 LogicalResult ReduceOp::verify() {
1929  // Verify block arg types.
1930  Block &block = getRegion().front();
1931 
1932  // The block takes index, extent, and aggregated values as arguments.
1933  auto blockArgsCount = getInitVals().size() + 2;
1934  if (block.getNumArguments() != blockArgsCount)
1935  return emitOpError() << "ReduceOp body is expected to have "
1936  << blockArgsCount << " arguments";
1937 
1938  // The first block argument is the index and must always be of type `index`.
1939  if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
1940  return emitOpError(
1941  "argument 0 of ReduceOp body is expected to be of IndexType");
1942 
1943  // The second block argument is the extent and must be of type `size` or
1944  // `index`, depending on whether the reduce operation is applied to a shape or
1945  // to an extent tensor.
1946  Type extentTy = block.getArgument(1).getType();
1947  if (llvm::isa<ShapeType>(getShape().getType())) {
1948  if (!llvm::isa<SizeType>(extentTy))
1949  return emitOpError("argument 1 of ReduceOp body is expected to be of "
1950  "SizeType if the ReduceOp operates on a ShapeType");
1951  } else {
1952  if (!llvm::isa<IndexType>(extentTy))
1953  return emitOpError(
1954  "argument 1 of ReduceOp body is expected to be of IndexType if the "
1955  "ReduceOp operates on an extent tensor");
1956  }
1957 
1958  for (const auto &type : llvm::enumerate(getInitVals()))
1959  if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1960  return emitOpError() << "type mismatch between argument "
1961  << type.index() + 2
1962  << " of ReduceOp body and initial value "
1963  << type.index();
1964  return success();
1965 }
1966 
1967 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1968  // Parse operands.
1970  Type shapeOrExtentTensorType;
1971  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1973  parser.parseColonType(shapeOrExtentTensorType) ||
1974  parser.parseOptionalArrowTypeList(result.types))
1975  return failure();
1976 
1977  // Resolve operands.
1978  auto initVals = llvm::ArrayRef(operands).drop_front();
1979  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1980  result.operands) ||
1981  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1982  result.operands))
1983  return failure();
1984 
1985  // Parse the body.
1986  Region *body = result.addRegion();
1987  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1988  return failure();
1989 
1990  // Parse attributes.
1991  if (parser.parseOptionalAttrDict(result.attributes))
1992  return failure();
1993 
1994  return success();
1995 }
1996 
1997 void ReduceOp::print(OpAsmPrinter &p) {
1998  p << '(' << getShape() << ", " << getInitVals()
1999  << ") : " << getShape().getType();
2000  p.printOptionalArrowTypeList(getResultTypes());
2001  p << ' ';
2002  p.printRegion(getRegion());
2003  p.printOptionalAttrDict((*this)->getAttrs());
2004 }
2005 
2006 #define GET_OP_CLASSES
2007 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2008 
2009 #define GET_TYPEDEF_CLASSES
2010 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
static bool isErrorPropagationPossible(TypeRange operandTypes)
Definition: Shape.cpp:67
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
Definition: Shape.cpp:962
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
Definition: Shape.cpp:84
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
Definition: Shape.cpp:97
static LogicalResult verifySizeOrIndexOp(Operation *op)
Definition: Shape.cpp:72
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1545
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation & back()
Definition: Block.h:150
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
Operation & front()
Definition: Block.h:151
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:93
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:209
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:110
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:221
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:447
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:444
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
operand_iterator operand_begin()
Definition: Operation.h:369
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_iterator operand_end()
Definition: Operation.h:370
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:91
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
ValueTypeRange< ValueRange > type_range
Definition: ValueRange.h:410
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition: Traits.cpp:25
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:60
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
bool isExtentTensorType(Type)
Definition: Shape.cpp:45
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
Definition: Shape.cpp:50
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition: Shape.cpp:41
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:490
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.