//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::shape;

#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"

namespace {
#include "ShapeCanonicalization.inc"
} // namespace

RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
  return RankedTensorType::get({rank}, IndexType::get(ctx));
}

bool shape::isExtentTensorType(Type type) {
  auto ranked = type.dyn_cast<RankedTensorType>();
  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
}

static LogicalResult getShapeVec(Value input,
                                 SmallVectorImpl<int64_t> &shapeValues) {
  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
    auto type = inputOp.getArg().getType().cast<ShapedType>();
    if (!type.hasRank())
      return failure();
    llvm::append_range(shapeValues, type.getShape());
    return success();
  }
  DenseIntElementsAttr attr;
  if (matchPattern(input, m_Constant(&attr))) {
    llvm::append_range(shapeValues, attr.getValues<int64_t>());
    return success();
  }
  return failure();
}
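
// Illustrative example (added for exposition, not from the upstream file):
// for
//   %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<3xindex>
// getShapeVec(%shape, shapeValues) appends [1, 2, 3]. A constant extent
// tensor matched by m_Constant is handled the same way; any other producer
// makes the function fail.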

static bool isErrorPropagationPossible(TypeRange operandTypes) {
  return llvm::any_of(operandTypes, [](Type ty) {
    return ty.isa<SizeType, ShapeType, ValueShapeType>();
  });
}

static LogicalResult verifySizeOrIndexOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!resultTy.isa<SizeType>())
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `size` to propagate them";
  }
  return success();
}

static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!resultTy.isa<ShapeType>())
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `shape` to propagate them";
  }
  return success();
}

template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
}

template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace

void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}

Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, type, value);
  return nullptr;
}

LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
  if (attribute.getName() == "shape.lib") {
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      DenseSet<StringAttr> key;
      for (auto it : arr) {
        if (!it.isa<SymbolRefAttr>())
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        for (auto mapping : shapeFnLib.getMapping()) {
          if (!key.insert(mapping.getName()).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.getName() << "`";
          }
        }
      }
      return success();
    }

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//

// TODO: Canonicalization should be implemented for shapes that can be
// determined through mixtures of the known dimensions of the inputs.
OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
  // Only the last operand is checked because AnyOp is commutative.
  if (operands.back())
    return operands.back();

  return nullptr;
}
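
// Illustrative example (added for exposition): since `shape.any` may return
// any of its operands, returning the last operand when it is constant is a
// valid fold, e.g.
//   %0 = shape.any %unknown, %known_const  // folds to the constant shape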

//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//

ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void AssumingOp::print(OpAsmPrinter &p) {
  bool yieldsResults = !getResults().empty();

  p << " " << getWitness();
  if (yieldsResults)
    p << " -> (" << getResultTypes() << ")";
  p << ' ';
  p.printRegion(getDoRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict((*this)->getAttrs());
}
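
// The custom assembly handled by the parser and printer above looks roughly
// like this (illustrative sketch, names invented):
//   %result = shape.assuming %witness -> (tensor<2xindex>) {
//     ...
//     shape.assuming_yield %value : tensor<2xindex>
//   }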

namespace {
// Removes AssumingOp with a passing witness and inlines the region.
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.getPassingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};

struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    for (auto [opResult, yieldOperand] :
         llvm::zip(op.getResults(), yieldOp.getOperands())) {
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire
    // region to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
} // namespace

void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}

// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // AssumingOp has unconditional control flow into the region and back to the
  // parent, so return the correct RegionSuccessor purely based on the index
  // being None or 0.
  if (index) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&getDoRegion()));
}

void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}

void AssumingOp::build(
    OpBuilder &builder, OperationState &result, Value witness,
    function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {

  result.addOperands(witness);
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();

  // Build body.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&bodyBlock);
  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
  builder.create<AssumingYieldOp>(result.location, yieldValues);

  SmallVector<Type, 2> assumingTypes;
  for (Value v : yieldValues)
    assumingTypes.push_back(v.getType());
  result.addTypes(assumingTypes);
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::AddOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
  // add(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  return constFoldBinaryOp<IntegerAttr>(
      operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
}
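
// Folding examples (illustrative): `shape.add %x, %zero` folds to `%x`; with
// two constant operands, e.g. 3 and 4, constFoldBinaryOp produces the
// IntegerAttr 7.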

LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//

namespace {

// Merge multiple `shape.assuming_all` operations together.
//
//   %0 = shape.assuming_all %w0, %w1
//   %1 = shape.assuming_all %w2, %0
//
// to:
//
//   %1 = shape.assuming_all %w2, %w0, %w1
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> operands;

    for (Value operand : op.getInputs()) {
      if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
        operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
      else
        operands.push_back(operand);
    }

    // We didn't find any other `assuming_all` ops to merge with.
    if (operands.size() == op.getNumOperands())
      return failure();

    // Replace with a new `assuming_all` operation with merged constraints.
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
    return success();
  }
};

// Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
// are subsumed by others.
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1
//   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//
//   %2 = shape.cstr_broadcastable %shape3, %shape4
//   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//
//   %4 = shape.assuming_all %0, %1, %2, %3
//
// to:
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//   %2 = shape.assuming_all %0, %1
//
// In this example if shapes [0, 1, 2] are broadcastable, then it means that
// shapes [0, 1] are broadcastable too, and can be removed from the list of
// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all `CstrBroadcastableOp` operands first.
    SetVector<CstrBroadcastableOp> operands;
    for (Value operand : op.getInputs()) {
      // TODO: Apply this optimization if some of the witnesses are not
      // produced by the `cstr_broadcastable`.
      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
      if (!broadcastable)
        return failure();

      operands.insert(broadcastable);
    }

    // Skip trivial `assuming_all` operations.
    if (operands.size() <= 1)
      return failure();

    // Collect shapes checked by `cstr_broadcastable` operands.
    SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
    for (auto cstr : operands) {
      DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
      shapes.emplace_back(cstr, std::move(shapesSet));
    }

    // Sort by the number of shape operands (larger to smaller).
    llvm::sort(shapes, [](auto a, auto b) {
      return a.first.getNumOperands() > b.first.getNumOperands();
    });

    // We start from the `cstr_broadcastable` operations with the largest
    // number of shape operands, and remove redundant `cstr_broadcastable`
    // operations. We do this until we find a set of `cstr_broadcastable`
    // operations with non-overlapping constraints.
    SmallVector<CstrBroadcastableOp> markedForErase;

    for (unsigned i = 0; i < shapes.size(); ++i) {
      auto isSubset = [&](auto pair) {
        return llvm::set_is_subset(pair.second, shapes[i].second);
      };

      // Keep redundant `cstr_broadcastable` operations to be erased.
      auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
      for (auto *it0 = it; it0 < shapes.end(); ++it0)
        markedForErase.push_back(it0->first);
      shapes.erase(it, shapes.end());
    }

    // We didn't find any operands that could be removed.
    if (markedForErase.empty())
      return failure();

    // Collect non-overlapping `cstr_broadcastable` constraints.
    SmallVector<Value> uniqueConstraints;
    for (auto &shape : shapes)
      uniqueConstraints.push_back(shape.first.getResult());

    // Replace with a new `assuming_all` operation ...
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);

    // ... and maybe erase `cstr_broadcastable` ops without uses.
    for (auto &op : markedForErase)
      if (op->use_empty())
        rewriter.eraseOp(op);

    return success();
  }
};

struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};

template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Find unique operands.
    SetVector<Value> unique(op.operand_begin(), op.operand_end());

    // Reduce op to equivalent with unique operands.
    if (unique.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
                                        unique.takeVector(), op->getAttrs());
      return success();
    }

    return failure();
  }
};
} // namespace

void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}

OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false.
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
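
// Folding examples (illustrative): for inputs (%w, true, true) the two
// statically passing witnesses are erased and folding fails, leaving
// `shape.assuming_all %w`; for inputs (%w, false) the false witness is
// returned, since the combined witness can never pass.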

LogicalResult AssumingAllOp::verify() {
  // Ensure that AssumingAllOp contains at least one operand
  if (getNumOperands() == 0)
    return emitOpError("no operands specified");

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (getShapes().size() > 2)
    return nullptr;

  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
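
// Folding example (illustrative): with constant operands [1, 3] and [2, 1],
// getBroadcastedShape succeeds and the op folds to the extent tensor [2, 3];
// incompatible constants such as [2] and [3] do not fold.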

LogicalResult BroadcastOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

namespace {
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto isPotentiallyNonEmptyShape = [](Value shape) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        if (extentTensorTy.getDimSize(0) == 0)
          return false;
      }
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        if (constShape.getShape().empty())
          return false;
      }
      return true;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));

    // Reduce op to equivalent without empty shape operands.
    if (newOperands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
                                        op->getAttrs());
      return success();
    }

    return failure();
  }
};

struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.getShapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (op.getType().isa<ShapeType>()) {
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        assert(!op.getType().isa<ShapeType>() &&
               !replacement.getType().isa<ShapeType>() &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};

struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};

template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) -> Value {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information.
        bool isInformationLosingCast =
            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
        if (isInformationLosingCast) {
          anyChange = true;
          return castOp.getSource();
        }
      }
      return operand;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));

    // Rewrite op if any change required.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};

struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
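
// Folding example (illustrative): constant operands [1, 2] and [3] fold to
// the concatenated extent tensor [1, 2, 3].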

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}

ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
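
// The syntax accepted and produced above looks like (illustrative):
//   %0 = shape.const_shape [1, 2, 3] : !shape.shape
//   %1 = shape.const_shape [4, 5, 6] : tensor<3xindex>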

OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }

void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}

LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
  if (!shape)
    return emitOptionalError(location, "missing shape attribute");
  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(shape.size())}, b.getIndexType())});
  return success();
}

bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
                                                        TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    // Shape type is compatible with all other valid return types.
    return true;
  return lhs == rhs;
}

//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns overlap with the considerations applied during
  // folding; they handle cases where additional shape information is inferred
  // at some point but does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}

// Return true if there is at most one attribute not representing a scalar
// broadcast.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}
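
// Illustrative example: for attributes {[], [], [2, 3]} this returns true,
// since a single non-scalar shape broadcasts trivially against scalars,
// whereas {[2], [3]} returns false. A null (non-constant) attribute
// conservatively counts as non-scalar.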

OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are
  // known on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}

LogicalResult CstrBroadcastableOp::verify() {
  // Ensure that CstrBroadcastableOp contains at least two operands
  if (getNumOperands() < 2)
    return emitOpError("required at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}

OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
  if (llvm::all_of(operands,
                   [&](Attribute a) { return a && a == operands[0]; }))
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly, we
  // cannot if there are any non-const inputs.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }

void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
  return getPassingAttr();
}

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return llvm::None;
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  Type valType = getValue().getType();
  auto valShapedType = valType.dyn_cast<ShapedType>();
  if (!valShapedType || !valShapedType.hasRank())
    return nullptr;
  Optional<int64_t> index = getConstantIndex();
  if (!index.has_value())
    return nullptr;
  if (index.value() >= valShapedType.getRank())
    return nullptr;
  auto extent = valShapedType.getDimSize(*index);
  if (ShapedType::isDynamic(extent))
    return nullptr;
  return IntegerAttr::get(IndexType::get(getContext()), extent);
}

LogicalResult mlir::shape::DimOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  DimOpAdaptor dimOp(operands);
  inferredReturnTypes.assign({dimOp.getIndex().getType()});
  return success();
}

bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::DimOp::verify() {
  auto st = getValue().getType().cast<ShapedType>();
  if (!st.hasRank())
    return success();
  if (auto index = getConstantIndex()) {
    if (*index < 0 || *index >= st.getRank())
      return emitOpError("index is out of range");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;

  // Division in APInt does not follow floor(lhs / rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  if (quotient.isNegative() && !remainder.isNullValue()) {
    quotient -= 1;
  }

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}
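
// Floor-division example (illustrative): for lhs = -7 and rhs = 2,
// APInt::sdivrem yields quotient -3 and remainder -1; since the quotient is
// negative and the remainder nonzero, it is adjusted to -4 = floor(-7 / 2).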

LogicalResult mlir::shape::DivOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
  bool allSame = true;
  if (!operands.empty() && !operands[0])
    return {};
  for (Attribute operand : operands.drop_front(1)) {
    if (!operand)
      return {};
    allSame = allSame && operand == operands[0];
  }
  return BoolAttr::get(getContext(), allSame);
}

//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//

OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return {};
}

void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// FromExtentsOp
//===----------------------------------------------------------------------===//

OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
  if (llvm::any_of(operands, [](Attribute a) { return !a; }))
    return nullptr;
  SmallVector<int64_t, 6> extents;
  for (auto attr : operands)
    extents.push_back(attr.cast<IntegerAttr>().getInt());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(extents);
}
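
// Folding example (illustrative): when all extents are constant, e.g.
//   %s = shape.from_extents %c1, %c2
// with %c1 = 1 and %c2 = 2, the op folds to the index tensor attribute
// [1, 2].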

//===----------------------------------------------------------------------===//
// FunctionLibraryOp
//===----------------------------------------------------------------------===//

void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = getMapping()
                  .get(op->getName().getIdentifier())
                  .dyn_cast_or_null<FlatSymbolRefAttr>();
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}

ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}

void FunctionLibraryOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  p.printOptionalAttrDictWithKeyword(
      (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(getMappingAttr());
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder builder(location->getContext());
  OperationState state(location, getOperationName());
  FuncOp::build(builder, state, name, type, attrs);
  return cast<FuncOp>(Operation::create(state));
}
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      Operation::dialect_attr_range attrs) {
  SmallVector<NamedAttribute, 8> attrRef(attrs);
  return create(location, name, type, llvm::makeArrayRef(attrRef));
}
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs,
                      ArrayRef<DictionaryAttr> argAttrs) {
  FuncOp func = create(location, name, type, attrs);
  func.setAllArgAttrs(argAttrs);
  return func;
}

void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs,
                   ArrayRef<DictionaryAttr> argAttrs) {
  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
                     builder.getStringAttr(name));
  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
                     TypeAttr::get(type));
  state.attributes.append(attrs.begin(), attrs.end());
  state.addRegion();

  if (argAttrs.empty())
    return;
  assert(type.getNumInputs() == argAttrs.size());
  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
                                                /*resultAttrs=*/llvm::None);
}

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto buildFuncType =
      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
         function_interface_impl::VariadicFlag,
         std::string &) { return builder.getFunctionType(argTypes, results); };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, buildFuncType);
}

void FuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//

Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return llvm::None;
}

OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!elements)
    return nullptr;
  Optional<int64_t> dim = getConstantDim();
  if (!dim.has_value())
    return nullptr;
  if (dim.value() >= elements.getNumElements())
    return nullptr;
  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
}
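
// Folding example (illustrative): for a constant shape [4, 5, 6] and a
// constant dimension 1, the fold returns the extent 5; an out-of-range
// dimension does not fold.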

void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
                        int64_t dim) {
  auto loc = result.location;
  auto dimAttr = builder.getIndexAttr(dim);
  if (shape.getType().isa<ShapeType>()) {
    Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
    build(builder, result, builder.getType<SizeType>(), shape, dim);
  } else {
    Value dim =
        builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
    build(builder, result, builder.getIndexType(), shape, dim);
  }
}

LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// IsBroadcastableOp
//===----------------------------------------------------------------------===//

void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}

OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // Can always broadcast fewer than two shapes.
  if (operands.size() < 2) {
    return BoolAttr::get(getContext(), true);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// MeetOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands.empty())
    return failure();

  auto isShapeType = [](Type arg) {
    if (arg.isa<ShapeType>())
      return true;
    return isExtentTensorType(arg);
  };

  ValueRange::type_range types = operands.getTypes();
  Type acc = types.front();
  for (auto t : drop_begin(types)) {
    Type l = acc, r = t;
    if (!l.isa<ShapeType, SizeType>())
      std::swap(l, r);

    // Handle sizes, propagate error type if present.
    if (l.isa<SizeType>()) {
      if (r.isa<SizeType, IndexType>())
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (l.isa<IndexType>()) {
      if (r.isa<IndexType>())
        acc = r;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (l.isa<ShapeType>()) {
      // Handle shapes, propagate error type if present.
      if (isShapeType(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (isExtentTensorType(l)) {
      auto rank1 = l.cast<RankedTensorType>().getShape()[0];
      auto rank2 = r.cast<RankedTensorType>().getShape()[0];
      if (ShapedType::isDynamic(rank1))
        acc = l;
      else if (ShapedType::isDynamic(rank2))
        acc = r;
      else if (rank1 != rank2)
        return emitOptionalError(location, "unequal shape cardinality");
      else
        acc = l;
    }
  }
  inferredReturnTypes.assign({acc});
  return success();
}

bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (!lhs.isa<ShapeType, SizeType>())
    std::swap(lhs, rhs);

  if (lhs.isa<SizeType>())
    return rhs.isa<SizeType, IndexType>();
  if (lhs.isa<ShapeType>())
    return rhs.isa<ShapeType, TensorType>();

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
  auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!shape)
    return {};
  int64_t rank = shape.getNumElements();
  Builder builder(getContext());
  return builder.getIndexAttr(rank);
}

/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3

namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}

LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
  // Fold only when argument constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}
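
// Folding example (illustrative): a constant shape [2, 3, 4] folds to the
// index attribute 24, the product of its extents.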

LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::NumElementsOp::verify() {
  return verifySizeOrIndexOp(*this);
}
1589 
1590 //===----------------------------------------------------------------------===//
1591 // MaxOp
1592 //===----------------------------------------------------------------------===//
1593 
1594 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1595  // If operands are equal, just propagate one.
1596  if (getLhs() == getRhs())
1597  return getLhs();
1598  return nullptr;
1599 }
1600 
1601 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1602  MLIRContext *context, Optional<Location> location, ValueRange operands,
1603  DictionaryAttr attributes, RegionRange regions,
1604  SmallVectorImpl<Type> &inferredReturnTypes) {
1605  if (operands[0].getType() == operands[1].getType())
1606  inferredReturnTypes.assign({operands[0].getType()});
1607  else
1608  inferredReturnTypes.assign({SizeType::get(context)});
1609  return success();
1610 }
1611 
1612 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1613  if (l.size() != 1 || r.size() != 1)
1614  return false;
1615  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1616  return true;
1617  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1618  return true;
1619  return false;
1620 }
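// Illustrative sketch: identical operand types are propagated, e.g.
// ```
// %r = shape.max %a, %b : !shape.shape, !shape.shape -> !shape.shape
// ```
// while mixing `!shape.size` and `index` operands infers the error-carrying
// `!shape.size` result type.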
1621 
1622 //===----------------------------------------------------------------------===//
1623 // MinOp
1624 //===----------------------------------------------------------------------===//
1625 
1626 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1627  // If operands are equal, just propagate one.
1628  if (getLhs() == getRhs())
1629  return getLhs();
1630  return nullptr;
1631 }
1632 
1633 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1634  MLIRContext *context, Optional<Location> location, ValueRange operands,
1635  DictionaryAttr attributes, RegionRange regions,
1636  SmallVectorImpl<Type> &inferredReturnTypes) {
1637  if (operands[0].getType() == operands[1].getType())
1638  inferredReturnTypes.assign({operands[0].getType()});
1639  else
1640  inferredReturnTypes.assign({SizeType::get(context)});
1641  return success();
1642 }
1643 
1644 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1645  if (l.size() != 1 || r.size() != 1)
1646  return false;
1647  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1648  return true;
1649  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1650  return true;
1651  return false;
1652 }
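// Illustrative note: MinOp mirrors MaxOp; `shape.min %s, %s` folds to `%s`
// directly, and mixed operand types likewise infer a `!shape.size` result.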
1653 
1654 //===----------------------------------------------------------------------===//
1655 // MulOp
1656 //===----------------------------------------------------------------------===//
1657 
1658 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1659  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1660  if (!lhs)
1661  return nullptr;
1662  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1663  if (!rhs)
1664  return nullptr;
1665  APInt folded = lhs.getValue() * rhs.getValue();
1666  Type indexTy = IndexType::get(getContext());
1667  return IntegerAttr::get(indexTy, folded);
1668 }
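// Illustrative sketch of the fold above with constant `index` operands:
// ```
// %a = arith.constant 6 : index
// %b = arith.constant 7 : index
// %p = shape.mul %a, %b : index, index -> index
// ```
// folds %p to the index attribute 42.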
1669 
1670 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1671  MLIRContext *context, Optional<Location> location, ValueRange operands,
1672  DictionaryAttr attributes, RegionRange regions,
1673  SmallVectorImpl<Type> &inferredReturnTypes) {
1674  if (operands[0].getType().isa<SizeType>() ||
1675  operands[1].getType().isa<SizeType>())
1676  inferredReturnTypes.assign({SizeType::get(context)});
1677  else
1678  inferredReturnTypes.assign({IndexType::get(context)});
1679  return success();
1680 }
1681 
1682 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1683  // SizeType is compatible with IndexType.
1684  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1685 }
1686 
1687 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1688 
1689 //===----------------------------------------------------------------------===//
1690 // ShapeOfOp
1691 //===----------------------------------------------------------------------===//
1692 
1693 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1694  auto type = getOperand().getType().dyn_cast<ShapedType>();
1695  if (!type || !type.hasStaticShape())
1696  return nullptr;
1697  Builder builder(getContext());
1698  return builder.getIndexTensorAttr(type.getShape());
1699 }
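// Illustrative note: for a statically shaped operand such as
// `tensor<2x3xf32>`, the fold above yields the constant index tensor [2, 3].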
1700 
1701 namespace {
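// Rewrites `shape.shape_of` on a tensor argument whose result is the
// error-carrying `!shape.shape` into the equivalent op yielding an extent
// tensor, e.g. (illustrative)
// ```
// %s = shape.shape_of %arg : tensor<?x?xf32> -> !shape.shape
// ```
// becomes
// ```
// %s = shape.shape_of %arg : tensor<?x?xf32> -> tensor<2xindex>
// ```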
1702 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
1703  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1704 
1705  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1706  PatternRewriter &rewriter) const override {
1707  if (!op.getArg().getType().isa<ShapedType>())
1708  return failure();
1709  if (op.getType().isa<ShapedType>())
1710  return failure();
1711 
1712  rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
1713  op.getArg());
1714  return success();
1715  }
1716 };
1717 
1718 // Canonicalize
1719 // ```
1720 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1721 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1722 // ```
1723 // to
1724 // ```
1725 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1726 // ```
1727 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1728  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1729 
1730  LogicalResult matchAndRewrite(tensor::CastOp op,
1731  PatternRewriter &rewriter) const override {
1732  auto ty = op.getType().dyn_cast<RankedTensorType>();
1733  if (!ty || ty.getRank() != 1)
1734  return failure();
1735 
1736  auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1737  if (!shapeOfOp)
1738  return failure();
1739 
1740  // Argument type must be ranked and must not conflict.
1741  auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
1742  if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1743  return failure();
1744 
1745  rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1746  return success();
1747  }
1748 };
1749 } // namespace
1750 
1751 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1752  MLIRContext *context) {
1753  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
1754  ExtractFromShapeOfExtentTensor>(context);
1755 }
1756 
1757 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1758  MLIRContext *context, Optional<Location> location, ValueRange operands,
1759  DictionaryAttr attributes, RegionRange regions,
1760  SmallVectorImpl<Type> &inferredReturnTypes) {
1761  if (operands[0].getType().isa<ValueShapeType>())
1762  inferredReturnTypes.assign({ShapeType::get(context)});
1763  else {
1764  auto shapedTy = operands[0].getType().cast<ShapedType>();
1765  int64_t rank =
1766  shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1767  Type indexTy = IndexType::get(context);
1768  Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1769  inferredReturnTypes.assign({extentTensorTy});
1770  }
1771  return success();
1772 }
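// Illustrative note: inference maps a `tensor<4x?xf32>` operand to
// `tensor<2xindex>`, an unranked tensor to `tensor<?xindex>`, and a
// `!shape.value_shape` operand to `!shape.shape`.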
1773 
1774 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1775  if (l.size() != 1 || r.size() != 1)
1776  return false;
1777  if (l == r)
1778  return true;
1779 
1780  Type lhs = l.front();
1781  Type rhs = r.front();
1782 
1783  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
1784  return false;
1785 
1786  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
1787  // Shape type is compatible with all other valid return types.
1788  return true;
1789 
1790  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1791  return true;
1792  return false;
1793 }
1794 
1795 LogicalResult shape::ShapeOfOp::verify() {
1796  return verifyShapeOrExtentTensorOp(*this);
1797 }
1798 
1799 //===----------------------------------------------------------------------===//
1800 // SizeToIndexOp
1801 //===----------------------------------------------------------------------===//
1802 
1803 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1804  // Constant values of both types, `shape.size` and `index`, are represented as
1805  // `IntegerAttr`s, which makes constant folding simple.
1806  if (Attribute arg = operands[0])
1807  return arg;
1808  return OpFoldResult();
1809 }
1810 
1811 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1812  MLIRContext *context) {
1813  patterns.add<IndexToSizeToIndexCanonicalization>(context);
1814 }
1815 
1816 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1817  if (inputs.size() != 1 || outputs.size() != 1)
1818  return false;
1819  return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
1820 }
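// Illustrative sketch: a cast such as
// ```
// %i = shape.size_to_index %s : !shape.size
// ```
// is compatible, while any output type other than `index` is rejected by the
// check above.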
1821 
1822 //===----------------------------------------------------------------------===//
1823 // YieldOp
1824 //===----------------------------------------------------------------------===//
1825 
1826 LogicalResult shape::YieldOp::verify() {
1827  auto *parentOp = (*this)->getParentOp();
1828  auto results = parentOp->getResults();
1829  auto operands = getOperands();
1830 
1831  if (parentOp->getNumResults() != getNumOperands())
1832  return emitOpError() << "number of operands does not match number of "
1833  "results of its parent";
1834  for (auto e : llvm::zip(results, operands))
1835  if (std::get<0>(e).getType() != std::get<1>(e).getType())
1836  return emitOpError() << "types mismatch between yield op and its parent";
1837 
1838  return success();
1839 }
1840 
1841 //===----------------------------------------------------------------------===//
1842 // SplitAtOp
1843 //===----------------------------------------------------------------------===//
1844 
1845 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1846  SmallVectorImpl<OpFoldResult> &results) {
1847  if (!operands[0] || !operands[1])
1848  return failure();
1849  auto shapeVec = llvm::to_vector<6>(
1850  operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1851  auto shape = llvm::makeArrayRef(shapeVec);
1852  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1853  // Verify that the split point is in the correct range.
1854  // TODO: Constant fold to an "error".
1855  int64_t rank = shape.size();
1856  if (-rank > splitPoint || splitPoint > rank)
1857  return failure();
1858  if (splitPoint < 0)
1859  splitPoint += shape.size();
1860  Builder builder(operands[0].getContext());
1861  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1862  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1863  return success();
1864 }
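// Illustrative note: folding a split of the constant shape [4, 5, 6] at split
// point -1 normalizes the split point to 2 (rank + -1) and produces the index
// tensors [4, 5] and [6].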
1865 
1866 //===----------------------------------------------------------------------===//
1867 // ToExtentTensorOp
1868 //===----------------------------------------------------------------------===//
1869 
1870 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1871  if (!operands[0])
1872  return OpFoldResult();
1873  Builder builder(getContext());
1874  auto shape = llvm::to_vector<6>(
1875  operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1876  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1877  builder.getIndexType());
1878  return DenseIntElementsAttr::get(type, shape);
1879 }
1880 
1881 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1882  if (inputs.size() != 1 || outputs.size() != 1)
1883  return false;
1884  if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
1885  if (!inputTensor.getElementType().isa<IndexType>() ||
1886  inputTensor.getRank() != 1)
1887  return false;
1888  } else if (!inputs[0].isa<ShapeType>()) {
1889  return false;
1890  }
1891 
1892  TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
1893  return outputTensor && outputTensor.getElementType().isa<IndexType>();
1894 }
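// Illustrative note: a constant shape operand [3, 4] folds to a dense
// `tensor<2xindex>` constant; cast compatibility admits `!shape.shape` or
// rank-1 index-tensor inputs and any index-tensor output.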
1895 
1896 //===----------------------------------------------------------------------===//
1897 // ReduceOp
1898 //===----------------------------------------------------------------------===//
1899 
1900 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1901  ValueRange initVals) {
1902  result.addOperands(shape);
1903  result.addOperands(initVals);
1904 
1905  Region *bodyRegion = result.addRegion();
1906  bodyRegion->push_back(new Block);
1907  Block &bodyBlock = bodyRegion->front();
1908  bodyBlock.addArgument(builder.getIndexType(), result.location);
1909 
1910  Type elementType;
1911  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1912  elementType = tensorType.getElementType();
1913  else
1914  elementType = SizeType::get(builder.getContext());
1915  bodyBlock.addArgument(elementType, shape.getLoc());
1916 
1917  for (Value initVal : initVals) {
1918  bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
1919  result.addTypes(initVal.getType());
1920  }
1921 }
1922 
1923 LogicalResult shape::ReduceOp::verify() {
1924  // Verify block arg types.
1925  Block &block = getRegion().front();
1926 
1927  // The block takes index, extent, and aggregated values as arguments.
1928  auto blockArgsCount = getInitVals().size() + 2;
1929  if (block.getNumArguments() != blockArgsCount)
1930  return emitOpError() << "ReduceOp body is expected to have "
1931  << blockArgsCount << " arguments";
1932 
1933  // The first block argument is the index and must always be of type `index`.
1934  if (!block.getArgument(0).getType().isa<IndexType>())
1935  return emitOpError(
1936  "argument 0 of ReduceOp body is expected to be of IndexType");
1937 
1938  // The second block argument is the extent and must be of type `size` or
1939  // `index`, depending on whether the reduce operation is applied to a shape or
1940  // to an extent tensor.
1941  Type extentTy = block.getArgument(1).getType();
1942  if (getShape().getType().isa<ShapeType>()) {
1943  if (!extentTy.isa<SizeType>())
1944  return emitOpError("argument 1 of ReduceOp body is expected to be of "
1945  "SizeType if the ReduceOp operates on a ShapeType");
1946  } else {
1947  if (!extentTy.isa<IndexType>())
1948  return emitOpError(
1949  "argument 1 of ReduceOp body is expected to be of IndexType if the "
1950  "ReduceOp operates on an extent tensor");
1951  }
1952 
1953  for (const auto &type : llvm::enumerate(getInitVals()))
1954  if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1955  return emitOpError() << "type mismatch between argument "
1956  << type.index() + 2
1957  << " of ReduceOp body and initial value "
1958  << type.index();
1959  return success();
1960 }
1961 
1962 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1963  // Parse operands.
1964  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1965  Type shapeOrExtentTensorType;
1966  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1967  OpAsmParser::Delimiter::Paren) ||
1968  parser.parseColonType(shapeOrExtentTensorType) ||
1969  parser.parseOptionalArrowTypeList(result.types))
1970  return failure();
1971 
1972  // Resolve operands.
1973  auto initVals = llvm::makeArrayRef(operands).drop_front();
1974  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1975  result.operands) ||
1976  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1977  result.operands))
1978  return failure();
1979 
1980  // Parse the body.
1981  Region *body = result.addRegion();
1982  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1983  return failure();
1984 
1985  // Parse attributes.
1986  if (parser.parseOptionalAttrDict(result.attributes))
1987  return failure();
1988 
1989  return success();
1990 }
1991 
1992 void ReduceOp::print(OpAsmPrinter &p) {
1993  p << '(' << getShape() << ", " << getInitVals()
1994  << ") : " << getShape().getType();
1995  p.printOptionalArrowTypeList(getResultTypes());
1996  p << ' ';
1997  p.printRegion(getRegion());
1998  p.printOptionalAttrDict((*this)->getAttrs());
1999 }
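// Illustrative sketch of the custom form parsed and printed above (body
// contents are an assumed example reduction):
// ```
// %num_elems = shape.reduce(%shape, %init) : tensor<3xindex> -> index {
//   ^bb0(%index: index, %extent: index, %acc: index):
//     %new_acc = arith.muli %acc, %extent : index
//     shape.yield %new_acc : index
// }
// ```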
2000 
2001 #define GET_OP_CLASSES
2002 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2003 
2004 #define GET_TYPEDEF_CLASSES
2005 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"