MLIR  20.0.0git
Shape.cpp
Go to the documentation of this file.
1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
12 
17 #include "mlir/Dialect/Traits.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
27 #include "llvm/ADT/SetOperations.h"
28 #include "llvm/ADT/SmallString.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/raw_ostream.h"
31 
32 using namespace mlir;
33 using namespace mlir::shape;
34 
35 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
36 
37 namespace {
38 #include "ShapeCanonicalization.inc"
39 } // namespace
40 
41 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
42  return RankedTensorType::get({rank}, IndexType::get(ctx));
43 }
44 
46  auto ranked = llvm::dyn_cast<RankedTensorType>(type);
47  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
48 }
49 
50 LogicalResult shape::getShapeVec(Value input,
51  SmallVectorImpl<int64_t> &shapeValues) {
52  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
53  auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
54  if (!type.hasRank())
55  return failure();
56  llvm::append_range(shapeValues, type.getShape());
57  return success();
58  }
60  if (matchPattern(input, m_Constant(&attr))) {
61  llvm::append_range(shapeValues, attr.getValues<int64_t>());
62  return success();
63  }
64  return failure();
65 }
66 
67 static bool isErrorPropagationPossible(TypeRange operandTypes) {
68  return llvm::any_of(operandTypes,
69  llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
70 }
71 
72 static LogicalResult verifySizeOrIndexOp(Operation *op) {
73  assert(op != nullptr && op->getNumResults() == 1);
74  Type resultTy = op->getResultTypes().front();
76  if (!llvm::isa<SizeType>(resultTy))
77  return op->emitOpError()
78  << "if at least one of the operands can hold error values then "
79  "the result must be of type `size` to propagate them";
80  }
81  return success();
82 }
83 
84 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
85  assert(op != nullptr && op->getNumResults() == 1);
86  Type resultTy = op->getResultTypes().front();
88  if (!llvm::isa<ShapeType>(resultTy))
89  return op->emitOpError()
90  << "if at least one of the operands can hold error values then "
91  "the result must be of type `shape` to propagate them";
92  }
93  return success();
94 }
95 
96 template <typename... Ty>
97 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
98  return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
99 }
100 
101 template <typename... Ty, typename... ranges>
102 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
103  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // InlinerInterface
108 //===----------------------------------------------------------------------===//
109 
110 namespace {
111 /// This class defines the interface for inlining shape dialect ops.
112 struct ShapeInlinerInterface : public DialectInlinerInterface {
114 
115  // Returns true if the given region 'src' can be inlined into the region
116  // 'dest' that is attached to an operation registered to the current dialect.
117  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
118  IRMapping &) const final {
119  return true;
120  }
121 
122  // Returns true if the given operation 'op', that is registered to this
123  // dialect, can be inlined into the region 'dest' that is attached to an
124  // operation registered to the current dialect.
125  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
126  IRMapping &) const final {
127  return true;
128  }
129 };
130 } // namespace
131 
/// Registers the shape dialect's operations, types, and interfaces.
void ShapeDialect::initialize() {
  // The op and type lists are produced by ODS-generated .inc files selected
  // via the GET_OP_LIST / GET_TYPEDEF_LIST macros.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
  // NOTE(review): the bufferization interface for these ops is presumably
  // attached by an external extension; declaring it as promised here lets MLIR
  // diagnose a missing registration — confirm where it is registered.
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
                            AssumingYieldOp>();
}
149 
151  Attribute value, Type type,
152  Location loc) {
153  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
154  return builder.create<ub::PoisonOp>(loc, type, poison);
155 
156  if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
157  return builder.create<ConstShapeOp>(
158  loc, type, llvm::cast<DenseIntElementsAttr>(value));
159  if (llvm::isa<SizeType>(type))
160  return builder.create<ConstSizeOp>(loc, type,
161  llvm::cast<IntegerAttr>(value));
162  if (llvm::isa<WitnessType>(type))
163  return builder.create<ConstWitnessOp>(loc, type,
164  llvm::cast<BoolAttr>(value));
165 
166  return arith::ConstantOp::materialize(builder, value, type, loc);
167 }
168 
169 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
170  NamedAttribute attribute) {
171  // Verify shape.lib attribute.
172  if (attribute.getName() == "shape.lib") {
173  if (!op->hasTrait<OpTrait::SymbolTable>())
174  return op->emitError(
175  "shape.lib attribute may only be on op implementing SymbolTable");
176 
177  if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
178  auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
179  if (!symbol)
180  return op->emitError("shape function library ")
181  << symbolRef << " not found";
182  return isa<shape::FunctionLibraryOp>(symbol)
183  ? success()
184  : op->emitError()
185  << symbolRef << " required to be shape function library";
186  }
187 
188  if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
189  // Verify all entries are function libraries and mappings in libraries
190  // refer to unique ops.
192  for (auto it : arr) {
193  if (!llvm::isa<SymbolRefAttr>(it))
194  return op->emitError(
195  "only SymbolRefAttr allowed in shape.lib attribute array");
196 
197  auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
198  SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
199  if (!shapeFnLib)
200  return op->emitError()
201  << it << " does not refer to FunctionLibraryOp";
202  for (auto mapping : shapeFnLib.getMapping()) {
203  if (!key.insert(mapping.getName()).second) {
204  return op->emitError("only one op to shape mapping allowed, found "
205  "multiple for `")
206  << mapping.getName() << "`";
207  }
208  }
209  }
210  return success();
211  }
212 
213  return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
214  "allowed as shape.lib attribute");
215  }
216  return success();
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // AnyOp
221 //===----------------------------------------------------------------------===//
222 
223 // TODO: Canonicalization should be implemented for shapes that can be
224 // determined through mixtures of the known dimensions of the inputs.
225 OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
226  // Only the last operand is checked because AnyOp is commutative.
227  if (adaptor.getInputs().back())
228  return adaptor.getInputs().back();
229 
230  return nullptr;
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // AssumingOp
235 //===----------------------------------------------------------------------===//
236 
237 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
238  result.regions.reserve(1);
239  Region *doRegion = result.addRegion();
240 
241  auto &builder = parser.getBuilder();
243  if (parser.parseOperand(cond) ||
244  parser.resolveOperand(cond, builder.getType<WitnessType>(),
245  result.operands))
246  return failure();
247 
248  // Parse optional results type list.
249  if (parser.parseOptionalArrowTypeList(result.types))
250  return failure();
251 
252  // Parse the region and add a terminator if elided.
253  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
254  return failure();
255  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
256 
257  // Parse the optional attribute list.
258  if (parser.parseOptionalAttrDict(result.attributes))
259  return failure();
260  return success();
261 }
262 
264  bool yieldsResults = !getResults().empty();
265 
266  p << " " << getWitness();
267  if (yieldsResults)
268  p << " -> (" << getResultTypes() << ")";
269  p << ' ';
270  p.printRegion(getDoRegion(),
271  /*printEntryBlockArgs=*/false,
272  /*printBlockTerminators=*/yieldsResults);
273  p.printOptionalAttrDict((*this)->getAttrs());
274 }
275 
276 namespace {
277 // Removes AssumingOp with a passing witness and inlines the region.
278 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
280 
281  LogicalResult matchAndRewrite(AssumingOp op,
282  PatternRewriter &rewriter) const override {
283  auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
284  if (!witness || !witness.getPassingAttr())
285  return failure();
286 
287  AssumingOp::inlineRegionIntoParent(op, rewriter);
288  return success();
289  }
290 };
291 
292 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
294 
295  LogicalResult matchAndRewrite(AssumingOp op,
296  PatternRewriter &rewriter) const override {
297  Block *body = op.getBody();
298  auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
299 
300  // Find used values.
301  SmallVector<Value, 4> newYieldOperands;
302  for (auto [opResult, yieldOperand] :
303  llvm::zip(op.getResults(), yieldOp.getOperands())) {
304  if (!opResult.getUses().empty()) {
305  newYieldOperands.push_back(yieldOperand);
306  }
307  }
308 
309  // Rewrite only if redundant results exist.
310  if (newYieldOperands.size() == yieldOp->getNumOperands())
311  return failure();
312 
313  // Replace yield op in the old assuming op's body and move the entire region
314  // to the new assuming op.
315  rewriter.setInsertionPointToEnd(body);
316  auto newYieldOp =
317  rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
318  rewriter.setInsertionPoint(op);
319  auto newOp = rewriter.create<AssumingOp>(
320  op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
321  newOp.getDoRegion().takeBody(op.getDoRegion());
322 
323  // Use the new results to replace the previously used ones.
324  SmallVector<Value, 4> replacementValues;
325  auto src = newOp.getResults().begin();
326  for (auto it : op.getResults()) {
327  if (it.getUses().empty())
328  replacementValues.push_back(nullptr);
329  else
330  replacementValues.push_back(*src++);
331  }
332  rewriter.replaceOp(op, replacementValues);
333  return success();
334  }
335 };
336 } // namespace
337 
338 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
339  MLIRContext *context) {
340  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
341 }
342 
343 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
344 void AssumingOp::getSuccessorRegions(
346  // AssumingOp has unconditional control flow into the region and back to the
347  // parent, so return the correct RegionSuccessor purely based on the index
348  // being None or 0.
349  if (!point.isParent()) {
350  regions.push_back(RegionSuccessor(getResults()));
351  return;
352  }
353 
354  regions.push_back(RegionSuccessor(&getDoRegion()));
355 }
356 
/// Inlines the body of `op` into the enclosing block and removes `op`.
///
/// The enclosing block is split at the rewriter's current insertion point.
/// The `doRegion` block is moved in between the two halves. `op`'s results are
/// replaced by the operands of the region's terminator, which is then erased.
/// Finally, the three blocks are merged back into one.
/// NOTE(review): assumes the rewriter's insertion point is at `op` (pattern
/// drivers set it there before matchAndRewrite) — confirm for other callers.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
376 
377 void AssumingOp::build(
378  OpBuilder &builder, OperationState &result, Value witness,
380  OpBuilder::InsertionGuard g(builder);
381 
382  result.addOperands(witness);
383  Region *bodyRegion = result.addRegion();
384  builder.createBlock(bodyRegion);
385 
386  // Build body.
387  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
388  builder.create<AssumingYieldOp>(result.location, yieldValues);
389 
390  SmallVector<Type, 2> assumingTypes;
391  for (Value v : yieldValues)
392  assumingTypes.push_back(v.getType());
393  result.addTypes(assumingTypes);
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // AddOp
398 //===----------------------------------------------------------------------===//
399 
400 LogicalResult mlir::shape::AddOp::inferReturnTypes(
401  MLIRContext *context, std::optional<Location> location,
402  AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
403  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
404  llvm::isa<SizeType>(adaptor.getRhs().getType()))
405  inferredReturnTypes.assign({SizeType::get(context)});
406  else
407  inferredReturnTypes.assign({IndexType::get(context)});
408  return success();
409 }
410 
411 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
412  // SizeType is compatible with IndexType.
413  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
414 }
415 
416 OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
417  // add(x, 0) -> x
418  if (matchPattern(getRhs(), m_Zero()))
419  return getLhs();
420 
421  return constFoldBinaryOp<IntegerAttr>(
422  adaptor.getOperands(),
423  [](APInt a, const APInt &b) { return std::move(a) + b; });
424 }
425 
426 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
427 
428 //===----------------------------------------------------------------------===//
429 // AssumingAllOp
430 //===----------------------------------------------------------------------===//
431 
432 namespace {
433 
434 // Merge multiple `shape.assuming_all` operations together.
435 //
436 // %0 = shape.assuming_all %w0, %w1
437 // %1 = shape.assuming_all %w2, %0
438 //
439 // to:
440 //
441 // %0 = shape.assuming_all %w0, %w2, %w2
442 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
444 
445  LogicalResult matchAndRewrite(AssumingAllOp op,
446  PatternRewriter &rewriter) const override {
447  SmallVector<Value> operands;
448 
449  for (Value operand : op.getInputs()) {
450  if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
451  operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
452  else
453  operands.push_back(operand);
454  }
455 
456  // We didn't find any other `assuming_all` ops to merge with.
457  if (operands.size() == op.getNumOperands())
458  return failure();
459 
460  // Replace with a new `assuming_all` operation with merged constraints.
461  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
462  return success();
463  }
464 };
465 
466 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
467 // are subsumed by others.
468 //
469 // %0 = shape.cstr_broadcastable %shape0, %shape1
470 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
471 //
472 // %2 = shape.cstr_broadcastable %shape3, %shape4
473 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
474 //
475 // %4 = shape.assuming_all %0, %1, %2, %3
476 //
477 // to:
478 //
479 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
480 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
481 // %2 = shape.assuming_all %0, %1
482 //
483 // In this example if shapes [0, 1, 2] are broadcastable, then it means that
484 // shapes [0, 1] are broadcastable too, and can be removed from the list of
485 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
486 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
487 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
489 
490  LogicalResult matchAndRewrite(AssumingAllOp op,
491  PatternRewriter &rewriter) const override {
492  // Collect all `CstrBroadcastableOp` operands first.
494  for (Value operand : op.getInputs()) {
495  // TODO: Apply this optimization if some of the witnesses are not
496  // produced by the `cstr_broadcastable`.
497  auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
498  if (!broadcastable)
499  return failure();
500 
501  operands.insert(broadcastable);
502  }
503 
504  // Skip trivial `assuming_all` operations.
505  if (operands.size() <= 1)
506  return failure();
507 
508  // Collect shapes checked by `cstr_broadcastable` operands.
510  for (auto cstr : operands) {
511  DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
512  shapes.emplace_back(cstr, std::move(shapesSet));
513  }
514 
515  // Sort by the number of shape operands (larger to smaller).
516  llvm::sort(shapes, [](auto a, auto b) {
517  return a.first.getNumOperands() > b.first.getNumOperands();
518  });
519 
520  // We start from the `cst_broadcastable` operations with largest number of
521  // shape operands, and remove redundant `cst_broadcastable` operations. We
522  // do this until we find a set of `cst_broadcastable` operations with
523  // non-overlapping constraints.
524  SmallVector<CstrBroadcastableOp> markedForErase;
525 
526  for (unsigned i = 0; i < shapes.size(); ++i) {
527  auto isSubset = [&](auto pair) {
528  return llvm::set_is_subset(pair.second, shapes[i].second);
529  };
530 
531  // Keep redundant `cstr_broadcastable` operations to be erased.
532  auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
533  for (auto *it0 = it; it0 < shapes.end(); ++it0)
534  markedForErase.push_back(it0->first);
535  shapes.erase(it, shapes.end());
536  }
537 
538  // We didn't find any operands that could be removed.
539  if (markedForErase.empty())
540  return failure();
541 
542  // Collect non-overlapping `cst_broadcastable` constraints.
543  SmallVector<Value> uniqueConstraints;
544  for (auto &shape : shapes)
545  uniqueConstraints.push_back(shape.first.getResult());
546 
547  // Replace with a new `assuming_all` operation ...
548  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
549 
550  // ... and maybe erase `cstr_broadcastable` ops without uses.
551  for (auto &op : markedForErase)
552  if (op->use_empty())
553  rewriter.eraseOp(op);
554 
555  return success();
556  }
557 };
558 
559 struct AssumingAllToCstrEqCanonicalization
560  : public OpRewritePattern<AssumingAllOp> {
562 
563  LogicalResult matchAndRewrite(AssumingAllOp op,
564  PatternRewriter &rewriter) const override {
565  SmallVector<Value, 8> shapes;
566  for (Value w : op.getInputs()) {
567  auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
568  if (!cstrEqOp)
569  return failure();
570  bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
571  return llvm::is_contained(shapes, s);
572  });
573  if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
574  return failure();
575  shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
576  }
577  rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
578  return success();
579  }
580 };
581 
582 template <typename OpTy>
583 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
585 
586  LogicalResult matchAndRewrite(OpTy op,
587  PatternRewriter &rewriter) const override {
588  // Find unique operands.
589  SetVector<Value> unique(op.operand_begin(), op.operand_end());
590 
591  // Reduce op to equivalent with unique operands.
592  if (unique.size() < op.getNumOperands()) {
593  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
594  unique.takeVector(), op->getAttrs());
595  return success();
596  }
597 
598  return failure();
599  }
600 };
601 } // namespace
602 
603 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
604  MLIRContext *context) {
605  patterns
606  .add<MergeAssumingAllOps, AssumingAllOneOp,
607  AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
608  RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
609 }
610 
611 OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
612  // Iterate in reverse to first handle all constant operands. They are
613  // guaranteed to be the tail of the inputs because this is commutative.
614  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
615  Attribute a = adaptor.getInputs()[idx];
616  // Cannot fold if any inputs are not constant;
617  if (!a)
618  return nullptr;
619 
620  // We do not need to keep statically known values after handling them in
621  // this method.
622  getOperation()->eraseOperand(idx);
623 
624  // Always false if any input is statically known false
625  if (!llvm::cast<BoolAttr>(a).getValue())
626  return a;
627  }
628  // If this is reached, all inputs were statically known passing.
629  return BoolAttr::get(getContext(), true);
630 }
631 
632 LogicalResult AssumingAllOp::verify() {
633  // Ensure that AssumingAllOp contains at least one operand
634  if (getNumOperands() == 0)
635  return emitOpError("no operands specified");
636 
637  return success();
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // BroadcastOp
642 //===----------------------------------------------------------------------===//
643 
644 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
645  if (getShapes().size() == 1) {
646  // Otherwise, we need a cast which would be a canonicalization, not folding.
647  if (getShapes().front().getType() != getType())
648  return nullptr;
649  return getShapes().front();
650  }
651 
652  // TODO: Support folding with more than 2 input shapes
653  if (getShapes().size() > 2)
654  return nullptr;
655 
656  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
657  return nullptr;
658  auto lhsShape = llvm::to_vector<6>(
659  llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
660  .getValues<int64_t>());
661  auto rhsShape = llvm::to_vector<6>(
662  llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
663  .getValues<int64_t>());
664  SmallVector<int64_t, 6> resultShape;
665 
666  // If the shapes are not compatible, we can't fold it.
667  // TODO: Fold to an "error".
668  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
669  return nullptr;
670 
671  Builder builder(getContext());
672  return builder.getIndexTensorAttr(resultShape);
673 }
674 
675 LogicalResult BroadcastOp::verify() {
676  return verifyShapeOrExtentTensorOp(*this);
677 }
678 
679 namespace {
680 template <typename OpTy>
681 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
683 
684  LogicalResult matchAndRewrite(OpTy op,
685  PatternRewriter &rewriter) const override {
686  auto isPotentiallyNonEmptyShape = [](Value shape) {
687  if (auto extentTensorTy =
688  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
689  if (extentTensorTy.getDimSize(0) == 0)
690  return false;
691  }
692  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
693  if (constShape.getShape().empty())
694  return false;
695  }
696  return true;
697  };
698  auto newOperands = llvm::to_vector<8>(
699  llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
700 
701  // Reduce op to equivalent without empty shape operands.
702  if (newOperands.size() < op.getNumOperands()) {
703  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
704  op->getAttrs());
705  return success();
706  }
707 
708  return failure();
709  }
710 };
711 
712 struct BroadcastForwardSingleOperandPattern
713  : public OpRewritePattern<BroadcastOp> {
715 
716  LogicalResult matchAndRewrite(BroadcastOp op,
717  PatternRewriter &rewriter) const override {
718  if (op.getNumOperands() != 1)
719  return failure();
720  Value replacement = op.getShapes().front();
721 
722  // Insert cast if needed.
723  if (replacement.getType() != op.getType()) {
724  auto loc = op.getLoc();
725  if (llvm::isa<ShapeType>(op.getType())) {
726  replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
727  } else {
728  assert(!llvm::isa<ShapeType>(op.getType()) &&
729  !llvm::isa<ShapeType>(replacement.getType()) &&
730  "expect extent tensor cast");
731  replacement =
732  rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
733  }
734  }
735 
736  rewriter.replaceOp(op, replacement);
737  return success();
738  }
739 };
740 
741 struct BroadcastFoldConstantOperandsPattern
742  : public OpRewritePattern<BroadcastOp> {
744 
745  LogicalResult matchAndRewrite(BroadcastOp op,
746  PatternRewriter &rewriter) const override {
747  SmallVector<int64_t, 8> foldedConstantShape;
748  SmallVector<Value, 8> newShapeOperands;
749  for (Value shape : op.getShapes()) {
750  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
751  SmallVector<int64_t, 8> newFoldedConstantShape;
753  foldedConstantShape,
754  llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
755  newFoldedConstantShape)) {
756  foldedConstantShape = newFoldedConstantShape;
757  continue;
758  }
759  }
760  newShapeOperands.push_back(shape);
761  }
762 
763  // Need at least two constant operands to fold anything.
764  if (op.getNumOperands() - newShapeOperands.size() < 2)
765  return failure();
766 
767  auto foldedConstantOperandsTy = RankedTensorType::get(
768  {static_cast<int64_t>(foldedConstantShape.size())},
769  rewriter.getIndexType());
770  newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
771  op.getLoc(), foldedConstantOperandsTy,
772  rewriter.getIndexTensorAttr(foldedConstantShape)));
773  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
774  newShapeOperands);
775  return success();
776  }
777 };
778 
779 template <typename OpTy>
780 struct CanonicalizeCastExtentTensorOperandsPattern
781  : public OpRewritePattern<OpTy> {
783 
784  LogicalResult matchAndRewrite(OpTy op,
785  PatternRewriter &rewriter) const override {
786  // Canonicalize operands.
787  bool anyChange = false;
788  auto canonicalizeOperand = [&](Value operand) -> Value {
789  if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
790  // Only eliminate the cast if it holds no shape information.
791  bool isInformationLoosingCast =
792  llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
793  if (isInformationLoosingCast) {
794  anyChange = true;
795  return castOp.getSource();
796  }
797  }
798  return operand;
799  };
800  auto newOperands = llvm::to_vector<8>(
801  llvm::map_range(op.getOperands(), canonicalizeOperand));
802 
803  // Rewrite op if any change required.
804  if (!anyChange)
805  return failure();
806  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
807  return success();
808  }
809 };
810 
811 struct BroadcastConcretizeResultTypePattern
812  : public OpRewritePattern<BroadcastOp> {
814 
815  LogicalResult matchAndRewrite(BroadcastOp op,
816  PatternRewriter &rewriter) const override {
817  // Only concretize dynamic extent tensor result types.
818  auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
819  if (!resultTy || !resultTy.isDynamicDim(0))
820  return failure();
821 
822  // Infer resulting shape rank if possible.
823  int64_t maxRank = 0;
824  for (Value shape : op.getShapes()) {
825  if (auto extentTensorTy =
826  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
827  // Cannot infer resulting shape rank if any operand is dynamically
828  // ranked.
829  if (extentTensorTy.isDynamicDim(0))
830  return failure();
831  maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
832  }
833  }
834 
835  auto newOp = rewriter.create<BroadcastOp>(
836  op.getLoc(), getExtentTensorType(getContext(), maxRank),
837  op.getShapes());
838  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
839  return success();
840  }
841 };
842 } // namespace
843 
844 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
845  MLIRContext *context) {
846  patterns.add<BroadcastConcretizeResultTypePattern,
847  BroadcastFoldConstantOperandsPattern,
848  BroadcastForwardSingleOperandPattern,
849  CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
850  RemoveDuplicateOperandsPattern<BroadcastOp>,
851  RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
852 }
853 
854 //===----------------------------------------------------------------------===//
855 // ConcatOp
856 //===----------------------------------------------------------------------===//
857 
858 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
859  if (!adaptor.getLhs() || !adaptor.getRhs())
860  return nullptr;
861  auto lhsShape = llvm::to_vector<6>(
862  llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
863  auto rhsShape = llvm::to_vector<6>(
864  llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
865  SmallVector<int64_t, 6> resultShape;
866  resultShape.append(lhsShape.begin(), lhsShape.end());
867  resultShape.append(rhsShape.begin(), rhsShape.end());
868  Builder builder(getContext());
869  return builder.getIndexTensorAttr(resultShape);
870 }
871 
872 //===----------------------------------------------------------------------===//
873 // ConstShapeOp
874 //===----------------------------------------------------------------------===//
875 
877  p << " ";
878  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
879  p << "[";
880  interleaveComma(getShape().getValues<int64_t>(), p);
881  p << "] : ";
882  p.printType(getType());
883 }
884 
885 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
886  if (parser.parseOptionalAttrDict(result.attributes))
887  return failure();
888  // We piggy-back on ArrayAttr parsing, though we don't internally store the
889  // shape as an ArrayAttr.
890  // TODO: Implement custom parser and maybe make syntax a bit more concise.
891  Attribute extentsRaw;
892  NamedAttrList dummy;
893  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
894  return failure();
895  auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
896  if (!extentsArray)
897  return failure();
899  for (Attribute extent : extentsArray) {
900  IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
901  if (!attr)
902  return failure();
903  ints.push_back(attr.getInt());
904  }
905  Builder &builder = parser.getBuilder();
906  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
907  Type resultTy;
908  if (parser.parseColonType(resultTy))
909  return failure();
910  result.types.push_back(resultTy);
911  return success();
912 }
913 
914 OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
915 
916 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
917  MLIRContext *context) {
918  patterns.add<TensorCastConstShape>(context);
919 }
920 
921 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
922  MLIRContext *context, std::optional<Location> location,
923  ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
924  Builder b(context);
925  const Properties prop = adaptor.getProperties();
926  inferredReturnTypes.assign({RankedTensorType::get(
927  {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
928  return success();
929 }
930 
931 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
932  TypeRange r) {
933  if (l.size() != 1 || r.size() != 1)
934  return false;
935 
936  Type lhs = l.front();
937  Type rhs = r.front();
938 
939  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
940  // Shape type is compatible with all other valid return types.
941  return true;
942  return lhs == rhs;
943 }
944 
945 //===----------------------------------------------------------------------===//
946 // CstrBroadcastableOp
947 //===----------------------------------------------------------------------===//
948 
949 void CstrBroadcastableOp::getCanonicalizationPatterns(
950  RewritePatternSet &patterns, MLIRContext *context) {
951  // Canonicalization patterns have overlap with the considerations during
952  // folding in case additional shape information is inferred at some point that
953  // does not result in folding.
954  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
955  CstrBroadcastableEqOps,
956  RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
957  RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
958 }
959 
// Return true if at most one attribute may describe a non-scalar shape (the
// function also returns true when there is none). A null attribute, i.e. an
// operand that is not a known constant, is conservatively counted as
// non-scalar; a constant shape with zero extents counts as a scalar.
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
      // A second possibly non-scalar operand means broadcasting may matter.
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}
973 
/// Folds `shape.cstr_broadcastable` to a passing witness (`true`) when the
/// operands are provably broadcastable; otherwise returns null and leaves the
/// op in place.
OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
    return BoolAttr::get(getContext(), true);

  // Next, try the case where every shape operand is a constant: gather the
  // extents of each constant shape. Any non-constant operand aborts this
  // attempt (the lambda returns false).
  if ([&] {
        for (const auto &operand : adaptor.getShapes()) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
        }
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  // Each operand's extents are recovered with getShapeVec (ranked shape_of or
  // a constant); the attempt is abandoned if any operand is not recoverable.
  if ([&] {
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
1008 
1009 LogicalResult CstrBroadcastableOp::verify() {
1010  // Ensure that CstrBroadcastableOp contains at least two operands
1011  if (getNumOperands() < 2)
1012  return emitOpError("required at least 2 input shapes");
1013  return success();
1014 }
1015 
1016 //===----------------------------------------------------------------------===//
1017 // CstrEqOp
1018 //===----------------------------------------------------------------------===//
1019 
1020 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1021  MLIRContext *context) {
1022  // If inputs are equal, return passing witness
1023  patterns.add<CstrEqEqOps>(context);
1024 }
1025 
1026 OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
1027  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
1028  return a && a == adaptor.getShapes().front();
1029  }))
1030  return BoolAttr::get(getContext(), true);
1031 
1032  // Because a failing witness result here represents an eventual assertion
1033  // failure, we do not try to replace it with a constant witness. Similarly, we
1034  // cannot if there are any non-const inputs.
1035  return nullptr;
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // ConstSizeOp
1040 //===----------------------------------------------------------------------===//
1041 
1042 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
1043  int64_t value) {
1044  build(builder, result, builder.getIndexAttr(value));
1045 }
1046 
1047 OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
1048 
1049 void ConstSizeOp::getAsmResultNames(
1050  llvm::function_ref<void(Value, StringRef)> setNameFn) {
1051  SmallString<4> buffer;
1052  llvm::raw_svector_ostream os(buffer);
1053  os << "c" << getValue();
1054  setNameFn(getResult(), os.str());
1055 }
1056 
1057 //===----------------------------------------------------------------------===//
1058 // ConstWitnessOp
1059 //===----------------------------------------------------------------------===//
1060 
1061 OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
1062 
1063 //===----------------------------------------------------------------------===//
1064 // CstrRequireOp
1065 //===----------------------------------------------------------------------===//
1066 
1067 OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
1068  return adaptor.getPred();
1069 }
1070 
1071 //===----------------------------------------------------------------------===//
1072 // DimOp
1073 //===----------------------------------------------------------------------===//
1074 
1075 std::optional<int64_t> DimOp::getConstantIndex() {
1076  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1077  return constSizeOp.getValue().getLimitedValue();
1078  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1079  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1080  return std::nullopt;
1081 }
1082 
1083 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1084  Type valType = getValue().getType();
1085  auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1086  if (!valShapedType || !valShapedType.hasRank())
1087  return nullptr;
1088  std::optional<int64_t> index = getConstantIndex();
1089  if (!index.has_value())
1090  return nullptr;
1091  if (index.value() < 0 || index.value() >= valShapedType.getRank())
1092  return nullptr;
1093  auto extent = valShapedType.getDimSize(*index);
1094  if (ShapedType::isDynamic(extent))
1095  return nullptr;
1096  return IntegerAttr::get(IndexType::get(getContext()), extent);
1097 }
1098 
1099 LogicalResult mlir::shape::DimOp::inferReturnTypes(
1100  MLIRContext *context, std::optional<Location> location,
1101  DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1102  inferredReturnTypes.assign({adaptor.getIndex().getType()});
1103  return success();
1104 }
1105 
1106 bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1107  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1108 }
1109 
1110 //===----------------------------------------------------------------------===//
1111 // DivOp
1112 //===----------------------------------------------------------------------===//
1113 
1114 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
1115  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1116  if (!lhs)
1117  return nullptr;
1118  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1119  if (!rhs)
1120  return nullptr;
1121 
1122  // Division in APInt does not follow floor(lhs, rhs) when the result is
1123  // negative. Rather, APInt rounds toward zero.
1124  APInt quotient, remainder;
1125  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1126  if (quotient.isNegative() && !remainder.isZero()) {
1127  quotient -= 1;
1128  }
1129 
1130  Type indexTy = IndexType::get(getContext());
1131  return IntegerAttr::get(indexTy, quotient);
1132 }
1133 
1134 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1135  MLIRContext *context, std::optional<Location> location,
1136  DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1137  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1138  llvm::isa<SizeType>(adaptor.getRhs().getType()))
1139  inferredReturnTypes.assign({SizeType::get(context)});
1140  else
1141  inferredReturnTypes.assign({IndexType::get(context)});
1142  return success();
1143 }
1144 
1145 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1146  // SizeType is compatible with IndexType.
1147  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1148 }
1149 
1150 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1151 
1152 //===----------------------------------------------------------------------===//
1153 // ShapeEqOp
1154 //===----------------------------------------------------------------------===//
1155 
1156 OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
1157  bool allSame = true;
1158  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1159  return {};
1160  for (Attribute operand : adaptor.getShapes().drop_front()) {
1161  if (!operand)
1162  return {};
1163  allSame = allSame && operand == adaptor.getShapes().front();
1164  }
1165  return BoolAttr::get(getContext(), allSame);
1166 }
1167 
1168 //===----------------------------------------------------------------------===//
1169 // IndexToSizeOp
1170 //===----------------------------------------------------------------------===//
1171 
1172 OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1173  // Constant values of both types, `shape.size` and `index`, are represented as
1174  // `IntegerAttr`s which makes constant folding simple.
1175  if (Attribute arg = adaptor.getArg())
1176  return arg;
1177  return {};
1178 }
1179 
1180 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1181  MLIRContext *context) {
1182  patterns.add<SizeToIndexToSizeCanonicalization>(context);
1183 }
1184 
1185 //===----------------------------------------------------------------------===//
1186 // FromExtentsOp
1187 //===----------------------------------------------------------------------===//
1188 
1189 OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1190  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
1191  return nullptr;
1192  SmallVector<int64_t, 6> extents;
1193  for (auto attr : adaptor.getExtents())
1194  extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1195  Builder builder(getContext());
1196  return builder.getIndexTensorAttr(extents);
1197 }
1198 
1199 //===----------------------------------------------------------------------===//
1200 // FunctionLibraryOp
1201 //===----------------------------------------------------------------------===//
1202 
/// Builds a FunctionLibraryOp by pushing a named attribute derived from `name`
/// onto `result.attributes`. Presumably this is the symbol-name attribute,
/// matching what parse() reads; confirm.
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
}
1208 
1209 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1210  auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1211  getMapping().get(op->getName().getIdentifier()));
1212  if (!attr)
1213  return nullptr;
1214  return lookupSymbol<FuncOp>(attr);
1215 }
1216 
1217 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1218  OperationState &result) {
1219  // Parse the op name.
1220  StringAttr nameAttr;
1221  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1222  result.attributes))
1223  return failure();
1224 
1225  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1226  return failure();
1227 
1228  auto *bodyRegion = result.addRegion();
1229  if (parser.parseRegion(*bodyRegion))
1230  return failure();
1231 
1232  if (parser.parseKeyword("mapping"))
1233  return failure();
1234 
1235  DictionaryAttr mappingAttr;
1236  if (parser.parseAttribute(mappingAttr,
1237  parser.getBuilder().getType<NoneType>(), "mapping",
1238  result.attributes))
1239  return failure();
1240  return success();
1241 }
1242 
  // Custom form: ` @name attributes {...} { body } mapping {...}`.
  p << ' ';
  p.printSymbolName(getName());
  // The symbol name and the mapping are printed explicitly, so both are
  // elided from the printed attribute dictionary.
      (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  // The body is printed without entry block arguments or terminators.
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(getMappingAttr());
}
1254 
1255 //===----------------------------------------------------------------------===//
1256 // FuncOp
1257 //===----------------------------------------------------------------------===//
1258 
1259 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1260  ArrayRef<NamedAttribute> attrs) {
1261  OpBuilder builder(location->getContext());
1262  OperationState state(location, getOperationName());
1263  FuncOp::build(builder, state, name, type, attrs);
1264  return cast<FuncOp>(Operation::create(state));
1265 }
/// Overload that copies `attrs` into a SmallVector<NamedAttribute> and
/// forwards to the ArrayRef<NamedAttribute> overload.
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
  SmallVector<NamedAttribute, 8> attrRef(attrs);
  return create(location, name, type, llvm::ArrayRef(attrRef));
}
/// Creates the function through the attribute-based overload, then sets the
/// attributes of all its arguments from `argAttrs`.
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<DictionaryAttr> argAttrs) {
  FuncOp func = create(location, name, type, attrs);
  func.setAllArgAttrs(argAttrs);
  return func;
}
1278 
/// Populates `state` for a `shape.func`:
///  - sets the symbol-name and function-type attributes;
///  - appends `attrs`;
///  - adds an (empty) body region.
/// When `argAttrs` is non-empty, it must hold one entry per function input.
/// The entries are then forwarded, with no result attributes, together with
/// the arg/res attribute names.
void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs,
                   ArrayRef<DictionaryAttr> argAttrs) {
  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
                     builder.getStringAttr(name));
  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
                     TypeAttr::get(type));
  state.attributes.append(attrs.begin(), attrs.end());
  state.addRegion();

  if (argAttrs.empty())
    return;
  assert(type.getNumInputs() == argAttrs.size());
      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}
1296 
/// Parses a `shape.func`. Variadic argument lists are rejected
/// (allowVariadic=false), and the function type is built as a plain
/// FunctionType from the parsed argument and result types.
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto buildFuncType =
      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
         std::string &) { return builder.getFunctionType(argTypes, results); };

      parser, result, /*allowVariadic=*/false,
      getFunctionTypeAttrName(result.name), buildFuncType,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
1308 
/// Prints a non-variadic `shape.func`, using its function-type and arg/result
/// attribute names.
void FuncOp::print(OpAsmPrinter &p) {
      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
      getArgAttrsAttrName(), getResAttrsAttrName());
}
1314 
1315 //===----------------------------------------------------------------------===//
1316 // GetExtentOp
1317 //===----------------------------------------------------------------------===//
1318 
1319 std::optional<int64_t> GetExtentOp::getConstantDim() {
1320  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1321  return constSizeOp.getValue().getLimitedValue();
1322  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1323  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1324  return std::nullopt;
1325 }
1326 
1327 OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
1328  auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1329  if (!elements)
1330  return nullptr;
1331  std::optional<int64_t> dim = getConstantDim();
1332  if (!dim.has_value())
1333  return nullptr;
1334  if (dim.value() >= elements.getNumElements())
1335  return nullptr;
1336  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1337 }
1338 
1339 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1340  int64_t dim) {
1341  auto loc = result.location;
1342  auto dimAttr = builder.getIndexAttr(dim);
1343  if (llvm::isa<ShapeType>(shape.getType())) {
1344  Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1345  build(builder, result, builder.getType<SizeType>(), shape, dim);
1346  } else {
1347  Value dim =
1348  builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1349  build(builder, result, builder.getIndexType(), shape, dim);
1350  }
1351 }
1352 
1353 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1354  MLIRContext *context, std::optional<Location> location,
1355  GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1356  inferredReturnTypes.assign({IndexType::get(context)});
1357  return success();
1358 }
1359 
1360 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1361  TypeRange r) {
1362  // SizeType is compatible with IndexType.
1363  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1364 }
1365 
1366 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1367 
1368 //===----------------------------------------------------------------------===//
1369 // IsBroadcastableOp
1370 //===----------------------------------------------------------------------===//
1371 
1372 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1373  MLIRContext *context) {
1374  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1375 }
1376 
1377 OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1378  // Can always broadcast fewer than two shapes.
1379  if (adaptor.getShapes().size() < 2) {
1380  return BoolAttr::get(getContext(), true);
1381  }
1382 
1383  return nullptr;
1384 }
1385 
1386 //===----------------------------------------------------------------------===//
1387 // MeetOp
1388 //===----------------------------------------------------------------------===//
1389 
1390 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1391  MLIRContext *context, std::optional<Location> location,
1392  MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1393  if (adaptor.getOperands().empty())
1394  return failure();
1395 
1396  auto isShapeType = [](Type arg) {
1397  if (llvm::isa<ShapeType>(arg))
1398  return true;
1399  return isExtentTensorType(arg);
1400  };
1401 
1402  ValueRange::type_range types = adaptor.getOperands().getTypes();
1403  Type acc = types.front();
1404  for (auto t : drop_begin(types)) {
1405  Type l = acc, r = t;
1406  if (!llvm::isa<ShapeType, SizeType>(l))
1407  std::swap(l, r);
1408 
1409  // Handle sizes, propagate error type if present.
1410  if (llvm::isa<SizeType>(l)) {
1411  if (llvm::isa<SizeType, IndexType>(r))
1412  acc = l;
1413  else
1414  return emitOptionalError(location, "requires all sizes or shapes");
1415  } else if (llvm::isa<IndexType>(l)) {
1416  if (llvm::isa<IndexType>(r))
1417  acc = r;
1418  else
1419  return emitOptionalError(location, "requires all sizes or shapes");
1420  } else if (llvm::isa<ShapeType>(l)) {
1421  // Handle shapes, propagate error type if present.
1422  if (isShapeType(r))
1423  acc = l;
1424  else
1425  return emitOptionalError(location, "requires all sizes or shapes");
1426  } else if (isExtentTensorType(l)) {
1427  auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
1428  auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
1429  if (ShapedType::isDynamic(rank1))
1430  acc = l;
1431  else if (ShapedType::isDynamic(rank2))
1432  acc = r;
1433  else if (rank1 != rank2)
1434  return emitOptionalError(location, "unequal shape cardinality");
1435  else
1436  acc = l;
1437  }
1438  }
1439  inferredReturnTypes.assign({acc});
1440  return success();
1441 }
1442 
1443 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1444  if (l.size() != 1 || r.size() != 1)
1445  return false;
1446  if (l == r)
1447  return true;
1448 
1449  Type lhs = l.front();
1450  Type rhs = r.front();
1451 
1452  if (!llvm::isa<ShapeType, SizeType>(lhs))
1453  std::swap(lhs, rhs);
1454 
1455  if (llvm::isa<SizeType>(lhs))
1456  return llvm::isa<SizeType, IndexType>(rhs);
1457  if (llvm::isa<ShapeType>(lhs))
1458  return llvm::isa<ShapeType, TensorType>(rhs);
1459 
1460  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1461  return true;
1462  return false;
1463 }
1464 
1465 //===----------------------------------------------------------------------===//
1466 // RankOp
1467 //===----------------------------------------------------------------------===//
1468 
1469 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1470  auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1471  if (!shape)
1472  return {};
1473  int64_t rank = shape.getNumElements();
1474  Builder builder(getContext());
1475  return builder.getIndexAttr(rank);
1476 }
1477 
/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3
///
/// When the `rank` result has `index` type, an index `arith.constant` is
/// created instead of `shape.const_size`. Any other result type is left
/// untouched.

namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    // Only shapes produced by shape_of of a ranked tensor have a known rank.
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the constant with an op matching the rank's result type.
    if (llvm::isa<IndexType>(op.getType())) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (llvm::isa<shape::SizeType>(op.getType())) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace
1519 
1520 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1521  MLIRContext *context) {
1522  patterns.add<RankShapeOfCanonicalizationPattern>(context);
1523 }
1524 
1525 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1526  MLIRContext *context, std::optional<Location> location,
1527  RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1528  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1529  inferredReturnTypes.assign({SizeType::get(context)});
1530  else
1531  inferredReturnTypes.assign({IndexType::get(context)});
1532  return success();
1533 }
1534 
1535 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1536  // SizeType is compatible with IndexType.
1537  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1538 }
1539 
1540 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1541 
1542 //===----------------------------------------------------------------------===//
1543 // NumElementsOp
1544 //===----------------------------------------------------------------------===//
1545 
1546 OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1547 
1548  // Fold only when argument constant.
1549  Attribute shape = adaptor.getShape();
1550  if (!shape)
1551  return {};
1552 
1553  APInt product(64, 1);
1554  for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
1555  product *= value;
1556  Builder builder(getContext());
1557  return builder.getIndexAttr(product.getLimitedValue());
1558 }
1559 
1560 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1561  MLIRContext *context, std::optional<Location> location,
1562  NumElementsOp::Adaptor adaptor,
1563  SmallVectorImpl<Type> &inferredReturnTypes) {
1564  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1565  inferredReturnTypes.assign({SizeType::get(context)});
1566  else
1567  inferredReturnTypes.assign({IndexType::get(context)});
1568  return success();
1569 }
1570 
1571 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1572  TypeRange r) {
1573  // SizeType is compatible with IndexType.
1574  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1575 }
1576 
1577 LogicalResult shape::NumElementsOp::verify() {
1578  return verifySizeOrIndexOp(*this);
1579 }
1580 
1581 //===----------------------------------------------------------------------===//
1582 // MaxOp
1583 //===----------------------------------------------------------------------===//
1584 
1585 OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
1586  // If operands are equal, just propagate one.
1587  if (getLhs() == getRhs())
1588  return getLhs();
1589  return nullptr;
1590 }
1591 
1592 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1593  MLIRContext *context, std::optional<Location> location,
1594  MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1595  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1596  inferredReturnTypes.assign({adaptor.getLhs().getType()});
1597  else
1598  inferredReturnTypes.assign({SizeType::get(context)});
1599  return success();
1600 }
1601 
1602 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1603  if (l.size() != 1 || r.size() != 1)
1604  return false;
1605  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1606  return true;
1607  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1608  return true;
1609  return false;
1610 }
1611 
1612 //===----------------------------------------------------------------------===//
1613 // MinOp
1614 //===----------------------------------------------------------------------===//
1615 
1616 OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
1617  // If operands are equal, just propagate one.
1618  if (getLhs() == getRhs())
1619  return getLhs();
1620  return nullptr;
1621 }
1622 
1623 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1624  MLIRContext *context, std::optional<Location> location,
1625  MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1626  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1627  inferredReturnTypes.assign({adaptor.getLhs().getType()});
1628  else
1629  inferredReturnTypes.assign({SizeType::get(context)});
1630  return success();
1631 }
1632 
1633 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1634  if (l.size() != 1 || r.size() != 1)
1635  return false;
1636  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1637  return true;
1638  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1639  return true;
1640  return false;
1641 }
1642 
1643 //===----------------------------------------------------------------------===//
1644 // MulOp
1645 //===----------------------------------------------------------------------===//
1646 
1647 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1648  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1649  if (!lhs)
1650  return nullptr;
1651  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1652  if (!rhs)
1653  return nullptr;
1654  APInt folded = lhs.getValue() * rhs.getValue();
1655  Type indexTy = IndexType::get(getContext());
1656  return IntegerAttr::get(indexTy, folded);
1657 }
1658 
1659 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1660  MLIRContext *context, std::optional<Location> location,
1661  MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1662  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1663  llvm::isa<SizeType>(adaptor.getRhs().getType()))
1664  inferredReturnTypes.assign({SizeType::get(context)});
1665  else
1666  inferredReturnTypes.assign({IndexType::get(context)});
1667  return success();
1668 }
1669 
1670 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1671  // SizeType is compatible with IndexType.
1672  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1673 }
1674 
1675 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1676 
1677 //===----------------------------------------------------------------------===//
1678 // ShapeOfOp
1679 //===----------------------------------------------------------------------===//
1680 
namespace {
/// Replace shape_of(x) where x has a constant shape with a const_shape op.
/// If the const_shape's type differs from the shape_of result type, a
/// tensor.cast bridges the two.
struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    // Only shaped arguments with a fully static shape qualify.
    auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
    if (!type || !type.hasStaticShape())
      return failure();
    Location loc = op.getLoc();
    Value constShape =
        rewriter
            .create<ConstShapeOp>(loc,
                                  rewriter.getIndexTensorAttr(type.getShape()))
            .getResult();
    if (constShape.getType() != op.getResult().getType())
      constShape = rewriter.create<tensor::CastOp>(
          loc, op.getResult().getType(), constShape);
    rewriter.replaceOp(op, constShape);
    return success();
  }
};

// Canonicalize
//
// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
//
// to
//
// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// %1 = %shape
//
// Only applies when the shape_of result is a tensor (not !shape.shape).
struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
    if (!tensorReshapeOp)
      return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
    if (!isa<TensorType>(op.getType()))
      return rewriter.notifyMatchFailure(op, "result is not a tensor");

    // Operand 'shape' of 'tensor.reshape' may now be used as the result of
    // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
    // formed IR, it may not be identical (dynamically vs statically shaped),
    // in which case it needs to be cast first.
    Value shape = tensorReshapeOp.getShape();
    if (op.getType() != shape.getType())
      shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);

    rewriter.replaceOp(op, shape);
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    // The cast must produce a 1-D ranked tensor.
    auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    // (A static cast length must equal the argument's rank.)
    auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
} // namespace
1771 
1772 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1773  MLIRContext *context) {
1774  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1775  ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1776  context);
1777 }
1778 
1779 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1780  MLIRContext *context, std::optional<Location> location,
1781  ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1782  if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1783  inferredReturnTypes.assign({ShapeType::get(context)});
1784  else {
1785  auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1786  int64_t rank =
1787  shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1788  Type indexTy = IndexType::get(context);
1789  Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1790  inferredReturnTypes.assign({extentTensorTy});
1791  }
1792  return success();
1793 }
1794 
1795 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1796  if (l.size() != 1 || r.size() != 1)
1797  return false;
1798  if (l == r)
1799  return true;
1800 
1801  Type lhs = l.front();
1802  Type rhs = r.front();
1803 
1804  if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1805  !llvm::isa<ShapeType, ShapedType>(rhs))
1806  return false;
1807 
1808  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1809  // Shape type is compatible with all other valid return types.
1810  return true;
1811 
1812  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1813  return true;
1814  return false;
1815 }
1816 
1817 LogicalResult shape::ShapeOfOp::verify() {
1818  return verifyShapeOrExtentTensorOp(*this);
1819 }
1820 
1821 //===----------------------------------------------------------------------===//
1822 // SizeToIndexOp
1823 //===----------------------------------------------------------------------===//
1824 
1825 OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1826  // Constant values of both types, `shape.size` and `index`, are represented as
1827  // `IntegerAttr`s which makes constant folding simple.
1828  if (Attribute arg = adaptor.getArg())
1829  return arg;
1830  return OpFoldResult();
1831 }
1832 
1833 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1834  MLIRContext *context) {
1835  patterns.add<IndexToSizeToIndexCanonicalization>(context);
1836 }
1837 
1838 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1839  if (inputs.size() != 1 || outputs.size() != 1)
1840  return false;
1841  return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1842  llvm::isa<IndexType>(outputs[0]);
1843 }
1844 
1845 //===----------------------------------------------------------------------===//
1846 // YieldOp
1847 //===----------------------------------------------------------------------===//
1848 
1849 LogicalResult shape::YieldOp::verify() {
1850  auto *parentOp = (*this)->getParentOp();
1851  auto results = parentOp->getResults();
1852  auto operands = getOperands();
1853 
1854  if (parentOp->getNumResults() != getNumOperands())
1855  return emitOpError() << "number of operands does not match number of "
1856  "results of its parent";
1857  for (auto e : llvm::zip(results, operands))
1858  if (std::get<0>(e).getType() != std::get<1>(e).getType())
1859  return emitOpError() << "types mismatch between yield op and its parent";
1860 
1861  return success();
1862 }
1863 
1864 //===----------------------------------------------------------------------===//
1865 // SplitAtOp
1866 //===----------------------------------------------------------------------===//
1867 
1868 LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1869  SmallVectorImpl<OpFoldResult> &results) {
1870  if (!adaptor.getOperand() || !adaptor.getIndex())
1871  return failure();
1872  auto shapeVec = llvm::to_vector<6>(
1873  llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1874  auto shape = llvm::ArrayRef(shapeVec);
1875  auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1876  // Verify that the split point is in the correct range.
1877  // TODO: Constant fold to an "error".
1878  int64_t rank = shape.size();
1879  if (-rank > splitPoint || splitPoint > rank)
1880  return failure();
1881  if (splitPoint < 0)
1882  splitPoint += shape.size();
1883  Builder builder(adaptor.getOperand().getContext());
1884  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1885  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1886  return success();
1887 }
1888 
1889 //===----------------------------------------------------------------------===//
1890 // ToExtentTensorOp
1891 //===----------------------------------------------------------------------===//
1892 
1893 OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1894  if (!adaptor.getInput())
1895  return OpFoldResult();
1896  Builder builder(getContext());
1897  auto shape = llvm::to_vector<6>(
1898  llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1899  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1900  builder.getIndexType());
1901  return DenseIntElementsAttr::get(type, shape);
1902 }
1903 
1904 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1905  if (inputs.size() != 1 || outputs.size() != 1)
1906  return false;
1907  if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1908  if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1909  inputTensor.getRank() != 1)
1910  return false;
1911  } else if (!llvm::isa<ShapeType>(inputs[0])) {
1912  return false;
1913  }
1914 
1915  TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1916  return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1917 }
1918 
1919 //===----------------------------------------------------------------------===//
1920 // ReduceOp
1921 //===----------------------------------------------------------------------===//
1922 
1923 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1924  ValueRange initVals) {
1925  OpBuilder::InsertionGuard g(builder);
1926  result.addOperands(shape);
1927  result.addOperands(initVals);
1928 
1929  Region *bodyRegion = result.addRegion();
1930  Block *bodyBlock = builder.createBlock(
1931  bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);
1932 
1933  Type elementType;
1934  if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
1935  elementType = tensorType.getElementType();
1936  else
1937  elementType = SizeType::get(builder.getContext());
1938  bodyBlock->addArgument(elementType, shape.getLoc());
1939 
1940  for (Value initVal : initVals) {
1941  bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
1942  result.addTypes(initVal.getType());
1943  }
1944 }
1945 
1946 LogicalResult ReduceOp::verify() {
1947  // Verify block arg types.
1948  Block &block = getRegion().front();
1949 
1950  // The block takes index, extent, and aggregated values as arguments.
1951  auto blockArgsCount = getInitVals().size() + 2;
1952  if (block.getNumArguments() != blockArgsCount)
1953  return emitOpError() << "ReduceOp body is expected to have "
1954  << blockArgsCount << " arguments";
1955 
1956  // The first block argument is the index and must always be of type `index`.
1957  if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
1958  return emitOpError(
1959  "argument 0 of ReduceOp body is expected to be of IndexType");
1960 
1961  // The second block argument is the extent and must be of type `size` or
1962  // `index`, depending on whether the reduce operation is applied to a shape or
1963  // to an extent tensor.
1964  Type extentTy = block.getArgument(1).getType();
1965  if (llvm::isa<ShapeType>(getShape().getType())) {
1966  if (!llvm::isa<SizeType>(extentTy))
1967  return emitOpError("argument 1 of ReduceOp body is expected to be of "
1968  "SizeType if the ReduceOp operates on a ShapeType");
1969  } else {
1970  if (!llvm::isa<IndexType>(extentTy))
1971  return emitOpError(
1972  "argument 1 of ReduceOp body is expected to be of IndexType if the "
1973  "ReduceOp operates on an extent tensor");
1974  }
1975 
1976  for (const auto &type : llvm::enumerate(getInitVals()))
1977  if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1978  return emitOpError() << "type mismatch between argument "
1979  << type.index() + 2
1980  << " of ReduceOp body and initial value "
1981  << type.index();
1982  return success();
1983 }
1984 
1985 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1986  // Parse operands.
1988  Type shapeOrExtentTensorType;
1989  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1991  parser.parseColonType(shapeOrExtentTensorType) ||
1992  parser.parseOptionalArrowTypeList(result.types))
1993  return failure();
1994 
1995  // Resolve operands.
1996  auto initVals = llvm::ArrayRef(operands).drop_front();
1997  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1998  result.operands) ||
1999  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2000  result.operands))
2001  return failure();
2002 
2003  // Parse the body.
2004  Region *body = result.addRegion();
2005  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
2006  return failure();
2007 
2008  // Parse attributes.
2009  if (parser.parseOptionalAttrDict(result.attributes))
2010  return failure();
2011 
2012  return success();
2013 }
2014 
2015 void ReduceOp::print(OpAsmPrinter &p) {
2016  p << '(' << getShape() << ", " << getInitVals()
2017  << ") : " << getShape().getType();
2018  p.printOptionalArrowTypeList(getResultTypes());
2019  p << ' ';
2020  p.printRegion(getRegion());
2021  p.printOptionalAttrDict((*this)->getAttrs());
2022 }
2023 
2024 #define GET_OP_CLASSES
2025 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2026 
2027 #define GET_TYPEDEF_CLASSES
2028 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
static bool isErrorPropagationPossible(TypeRange operandTypes)
Definition: Shape.cpp:67
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
Definition: Shape.cpp:962
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
Definition: Shape.cpp:84
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
Definition: Shape.cpp:97
static LogicalResult verifySizeOrIndexOp(Operation *op)
Definition: Shape.cpp:72
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
Operation & front()
Definition: Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:120
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:99
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:302
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:233
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:95
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:134
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:221
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:453
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:450
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:102
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
ValueTypeRange< ValueRange > type_range
Definition: ValueRange.h:410
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition: Traits.cpp:25
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:60
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
bool isExtentTensorType(Type)
Definition: Shape.cpp:45
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
Definition: Shape.cpp:50
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition: Shape.cpp:41
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:437
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.