MLIR  14.0.0git
Shape.cpp
Go to the documentation of this file.
1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
12 
16 #include "mlir/Dialect/Traits.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/raw_ostream.h"
27 
28 using namespace mlir;
29 using namespace mlir::shape;
30 
31 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
32 
33 namespace {
34 #include "ShapeCanonicalization.inc"
35 } // namespace
36 
37 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
38  return RankedTensorType::get({rank}, IndexType::get(ctx));
39 }
40 
// Returns true if `type` is a ranked tensor of rank 1 whose element type is
// `index`. Both statically sized (as built by getExtentTensorType) and
// dynamically sized extent tensors match.
// NOTE(review): the function header line is missing from this rendered
// listing; confirm it against the original source.
42  auto ranked = type.dyn_cast<RankedTensorType>();
43  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
44 }
45 
// Tries to recover static extents for the shape value `input`. On success the
// extents overwrite `shapeValues`; on failure `shapeValues` is left untouched.
// Recognized producers: ShapeOfOp of a ranked shaped value, ConstShapeOp, and
// arith::ConstantOp holding a DenseIntElementsAttr.
// NOTE(review): the first line of the signature is missing from this listing.
47  SmallVectorImpl<int64_t> &shapeValues) {
48  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
49  auto type = inputOp.getArg().getType().dyn_cast<ShapedType>();
// NOTE(review): `type` is null when the shape_of argument is not a
// ShapedType, and hasRank() would then be called on a null type. If such
// arguments can occur, this should read `if (!type || !type.hasRank())`.
50  if (!type.hasRank())
51  return failure();
52  shapeValues = llvm::to_vector<6>(type.getShape());
53  return success();
54  }
55  if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
56  shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>());
57  return success();
58  }
59  if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) {
// NOTE(review): cast<> asserts if the constant's value is not a
// DenseIntElementsAttr; dyn_cast + failure() would be more forgiving.
60  shapeValues = llvm::to_vector<6>(
61  inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>());
62  return success();
63  }
64  return failure();
65 }
66 
67 static bool isErrorPropagationPossible(TypeRange operandTypes) {
68  return llvm::any_of(operandTypes, [](Type ty) {
69  return ty.isa<SizeType, ShapeType, ValueShapeType>();
70  });
71 }
72 
// Verifies a single-result op: when an operand can hold an error value, the
// result must be of type `size` so the error can be propagated.
// NOTE(review): the function header and the opening `if (...) {` of the
// guarded section are missing from this listing (hence the unmatched `}`).
// Presumably the guard is isErrorPropagationPossible(op->getOperandTypes());
// confirm against the original source.
74  assert(op != nullptr && op->getNumResults() == 1);
75  Type resultTy = op->getResultTypes().front();
77  if (!resultTy.isa<SizeType>())
78  return op->emitOpError()
79  << "if at least one of the operands can hold error values then "
80  "the result must be of type `size` to propagate them";
81  }
82  return success();
83 }
84 
// Verifies a single-result op: when an operand can hold an error value, the
// result must be of type `shape` so the error can be propagated.
// NOTE(review): the function header and the opening `if (...) {` of the
// guarded section are missing from this listing (hence the unmatched `}`).
// Presumably the guard is isErrorPropagationPossible(op->getOperandTypes());
// confirm against the original source.
86  assert(op != nullptr && op->getNumResults() == 1);
87  Type resultTy = op->getResultTypes().front();
89  if (!resultTy.isa<ShapeType>())
90  return op->emitOpError()
91  << "if at least one of the operands can hold error values then "
92  "the result must be of type `shape` to propagate them";
93  }
94  return success();
95 }
96 
97 template <typename... Ty>
98 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
99  return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
100 }
101 
102 template <typename... Ty, typename... ranges>
103 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
104  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // InlinerInterface
109 //===----------------------------------------------------------------------===//
110 
111 namespace {
112 /// This class defines the interface for inlining shape dialect ops.
/// Both hooks below return true unconditionally, so every region and every
/// operation of this dialect is considered legal to inline.
// NOTE(review): the inherited-constructor declaration of this struct appears
// to be missing from this rendered listing.
113 struct ShapeInlinerInterface : public DialectInlinerInterface {
115 
116  // Returns true if the given region 'src' can be inlined into the region
117  // 'dest' that is attached to an operation registered to the current dialect.
118  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
119  BlockAndValueMapping &) const final {
120  return true;
121  }
122 
123  // Returns true if the given operation 'op', that is registered to this
124  // dialect, can be inlined into the region 'dest' that is attached to an
125  // operation registered to the current dialect.
126  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
127  BlockAndValueMapping &) const final {
128  return true;
129  }
130 };
131 } // namespace
132 
133 void ShapeDialect::initialize() {
134  addOperations<
135 #define GET_OP_LIST
136 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
137  >();
138  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
139  addInterfaces<ShapeInlinerInterface>();
140  // Allow unknown operations during prototyping and testing. As the dialect is
141  // still evolving it makes it simple to start with an unregistered ops and
142  // try different variants before actually defining the op.
143  allowUnknownOperations();
144 }
145 
// Materializes a constant op of `type` holding the folded `value` at `loc`:
//   - !shape.shape or an extent tensor type -> ConstShapeOp
//     (value must be a DenseIntElementsAttr),
//   - !shape.size -> ConstSizeOp (value must be an IntegerAttr),
//   - !shape.witness -> ConstWitnessOp (value must be a BoolAttr),
//   - otherwise arith::ConstantOp, if it can be built from (value, type).
// Returns nullptr when no case applies. The attribute cast<>s assert if the
// attribute kind does not match the type.
// NOTE(review): the first line of the signature is missing from this listing.
147  Attribute value, Type type,
148  Location loc) {
149  if (type.isa<ShapeType>() || isExtentTensorType(type))
150  return builder.create<ConstShapeOp>(loc, type,
151  value.cast<DenseIntElementsAttr>());
152  if (type.isa<SizeType>())
153  return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
154  if (type.isa<WitnessType>())
155  return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
156  if (arith::ConstantOp::isBuildableWith(value, type))
157  return builder.create<arith::ConstantOp>(loc, type, value);
158  return nullptr;
159 }
160 
161 /// Parse a type registered to this dialect.
/// Accepts the keywords `shape`, `size`, `value_shape` and `witness`. Any
/// other keyword is reported as "unknown shape type" and a null Type is
/// returned; a null Type is also returned if no keyword can be parsed.
// NOTE(review): the signature line is missing from this rendered listing.
163  StringRef keyword;
164  if (parser.parseKeyword(&keyword))
165  return Type();
166 
167  if (keyword == "shape")
168  return ShapeType::get(getContext());
169  if (keyword == "size")
170  return SizeType::get(getContext());
171  if (keyword == "value_shape")
172  return ValueShapeType::get(getContext());
173  if (keyword == "witness")
174  return WitnessType::get(getContext());
175 
176  parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
177  return Type();
178 }
179 
180 /// Print a type registered to this dialect.
181 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
182  TypeSwitch<Type>(type)
183  .Case<ShapeType>([&](Type) { os << "shape"; })
184  .Case<SizeType>([&](Type) { os << "size"; })
185  .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
186  .Case<WitnessType>([&](Type) { os << "witness"; })
187  .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
188 }
189 
// Verifies dialect attributes attached to `op`. Only `shape.lib` is checked;
// every other attribute is accepted. A `shape.lib` attribute requires `op` to
// be a SymbolTable, and its value must be either a single SymbolRefAttr
// resolving to a FunctionLibraryOp, or an array of such references whose
// mappings name distinct ops.
190 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
191  NamedAttribute attribute) {
192  // Verify shape.lib attribute.
193  if (attribute.getName() == "shape.lib") {
194  if (!op->hasTrait<OpTrait::SymbolTable>())
195  return op->emitError(
196  "shape.lib attribute may only be on op implementing SymbolTable");
197 
// Single reference: it must resolve, and to a FunctionLibraryOp.
198  if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
199  auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
200  if (!symbol)
201  return op->emitError("shape function library ")
202  << symbolRef << " not found";
203  return isa<shape::FunctionLibraryOp>(symbol)
204  ? success()
205  : op->emitError()
206  << symbolRef << " required to be shape function library";
207  }
208 
209  if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
210  // Verify all entries are function libraries and mappings in libraries
211  // refer to unique ops.
// NOTE(review): the declaration of the set `key` used below (it collects the
// mapping names already seen) is missing from this rendered listing.
213  for (auto it : arr) {
214  if (!it.isa<SymbolRefAttr>())
215  return op->emitError(
216  "only SymbolRefAttr allowed in shape.lib attribute array");
217 
// NOTE(review): lookupSymbolIn returns null for an unresolved reference, and
// dyn_cast<> on a null pointer asserts. dyn_cast_or_null<> would let the
// "does not refer to FunctionLibraryOp" diagnostic below fire instead.
218  auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
219  SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
220  if (!shapeFnLib)
221  return op->emitError()
222  << it << " does not refer to FunctionLibraryOp";
223  for (auto mapping : shapeFnLib.getMapping()) {
224  if (!key.insert(mapping.getName()).second) {
225  return op->emitError("only one op to shape mapping allowed, found "
226  "multiple for `")
227  << mapping.getName() << "`";
228  }
229  }
230  }
231  return success();
232  }
233 
234  return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
235  "allowed as shape.lib attribute");
236  }
237  return success();
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // AnyOp
242 //===----------------------------------------------------------------------===//
243 
244 // TODO: Canonicalization should be implemented for shapes that can be
245 // determined through mixtures of the known dimensions of the inputs.
246 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
247  // Only the last operand is checked because AnyOp is commutative.
248  if (operands.back())
249  return operands.back();
250 
251  return nullptr;
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // AssumingOp
256 //===----------------------------------------------------------------------===//
257 
// Parses the custom form of AssumingOp:
//   %witness [-> (result-types)] region [attr-dict]
// The witness operand is resolved with type !shape.witness. If the region's
// terminator was elided in the source, one is added.
// NOTE(review): the first signature line and the declaration of the operand
// `cond` are missing from this rendered listing.
259  OperationState &result) {
260  result.regions.reserve(1);
261  Region *doRegion = result.addRegion();
262 
263  auto &builder = parser.getBuilder();
265  if (parser.parseOperand(cond) ||
266  parser.resolveOperand(cond, builder.getType<WitnessType>(),
267  result.operands))
268  return failure();
269 
270  // Parse optional results type list.
271  if (parser.parseOptionalArrowTypeList(result.types))
272  return failure();
273 
274  // Parse the region and add a terminator if elided.
275  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
276  return failure();
277  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
278 
279  // Parse the optional attribute list.
280  if (parser.parseOptionalAttrDict(result.attributes))
281  return failure();
282  return success();
283 }
284 
285 static void print(OpAsmPrinter &p, AssumingOp op) {
286  bool yieldsResults = !op.getResults().empty();
287 
288  p << " " << op.getWitness();
289  if (yieldsResults)
290  p << " -> (" << op.getResultTypes() << ")";
291  p << ' ';
292  p.printRegion(op.getDoRegion(),
293  /*printEntryBlockArgs=*/false,
294  /*printBlockTerminators=*/yieldsResults);
295  p.printOptionalAttrDict(op->getAttrs());
296 }
297 
298 namespace {
299 // Removes AssumingOp with a passing witness and inlines the region.
// NOTE(review): the `using OpRewritePattern<...>::OpRewritePattern;` lines of
// both patterns appear to be missing from this rendered listing.
300 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
302 
303  LogicalResult matchAndRewrite(AssumingOp op,
304  PatternRewriter &rewriter) const override {
305  auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
// NOTE(review): if getPassingAttr() returns the BoolAttr itself (as its name
// suggests), `!` only tests the attribute for null, not for a false value.
// A constant failing witness would then also be inlined. Confirm whether
// the boolean value (e.g. getPassing()) was intended.
306  if (!witness || !witness.getPassingAttr())
307  return failure();
308 
309  AssumingOp::inlineRegionIntoParent(op, rewriter);
310  return success();
311  }
312 };
313 
// Removes AssumingOp results that have no uses. The yield is rebuilt with
// only the used values, the body is moved into a new AssumingOp with the
// trimmed result list, and the old op's used results are remapped to the new
// op's results in order.
314 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
316 
317  LogicalResult matchAndRewrite(AssumingOp op,
318  PatternRewriter &rewriter) const override {
319  Block *body = op.getBody();
320  auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
321 
322  // Find used values.
323  SmallVector<Value, 4> newYieldOperands;
324  Value opResult, yieldOperand;
325  for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
326  std::tie(opResult, yieldOperand) = it;
327  if (!opResult.getUses().empty()) {
328  newYieldOperands.push_back(yieldOperand);
329  }
330  }
331 
332  // Rewrite only if redundant results exist.
333  if (newYieldOperands.size() == yieldOp->getNumOperands())
334  return failure();
335 
336  // Replace yield op in the old assuming op's body and move the entire region
337  // to the new assuming op.
338  rewriter.setInsertionPointToEnd(body);
339  auto newYieldOp =
340  rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
341  rewriter.setInsertionPoint(op);
342  auto newOp = rewriter.create<AssumingOp>(
343  op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
344  newOp.getDoRegion().takeBody(op.getDoRegion());
345 
346  // Use the new results to replace the previously used ones.
// Unused results get a null replacement; they have no uses to rewrite.
347  SmallVector<Value, 4> replacementValues;
348  auto src = newOp.getResults().begin();
349  for (auto it : op.getResults()) {
350  if (it.getUses().empty())
351  replacementValues.push_back(nullptr);
352  else
353  replacementValues.push_back(*src++);
354  }
355  rewriter.replaceOp(op, replacementValues);
356  return success();
357  }
358 };
359 } // namespace
360 
361 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
362  MLIRContext *context) {
363  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
364 }
365 
366 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
// From the parent op (no index) control enters `doRegion`; from the region
// (index present) control returns to the parent op's results.
// NOTE(review): the trailing `regions` output parameter line is missing from
// this rendered listing.
367 void AssumingOp::getSuccessorRegions(
368  Optional<unsigned> index, ArrayRef<Attribute> operands,
370  // AssumingOp has unconditional control flow into the region and back to the
371  // parent, so return the correct RegionSuccessor purely based on the index
372  // being None or 0.
373  if (index.hasValue()) {
374  regions.push_back(RegionSuccessor(getResults()));
375  return;
376  }
377 
378  regions.push_back(RegionSuccessor(&getDoRegion()));
379 }
380 
// Inlines the body of `op` into its parent block and replaces the op's results
// with the operands of its terminator (the AssumingYieldOp).
// The parent block is split at the rewriter's current insertion point, the
// body block is spliced between the two halves, and all three are merged back
// into one block.
// NOTE(review): the resulting op order is only correct if the rewriter's
// insertion point is positioned at `op` on entry (presumably arranged by the
// pattern driver for AssumingWithTrue); confirm for any other caller.
381 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
382  PatternRewriter &rewriter) {
383  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
384  auto *assumingBlock = op.getBody();
385  auto initPosition = rewriter.getInsertionPoint();
386  auto *blockAfterAssuming =
387  rewriter.splitBlock(blockBeforeAssuming, initPosition);
388 
389  // Remove the AssumingOp and AssumingYieldOp.
390  auto &yieldOp = assumingBlock->back();
391  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
392  rewriter.replaceOp(op, yieldOp.getOperands());
393  rewriter.eraseOp(&yieldOp);
394 
395  // Merge blocks together as there was no branching behavior from the
396  // AssumingOp.
397  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
398  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
399 }
400 
// Builds an AssumingOp guarded by `witness`. A single-block body is created
// and populated by `bodyBuilder`; the values it returns are yielded by an
// AssumingYieldOp, and their types become the op's result types.
// NOTE(review): the `bodyBuilder` parameter line is missing from this
// rendered listing.
401 void AssumingOp::build(
402  OpBuilder &builder, OperationState &result, Value witness,
404 
405  result.addOperands(witness);
406  Region *bodyRegion = result.addRegion();
407  bodyRegion->push_back(new Block);
408  Block &bodyBlock = bodyRegion->front();
409 
410  // Build body.
// The guard restores the builder's insertion point when build() returns.
411  OpBuilder::InsertionGuard guard(builder);
412  builder.setInsertionPointToStart(&bodyBlock);
413  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
414  builder.create<AssumingYieldOp>(result.location, yieldValues);
415 
416  SmallVector<Type, 2> assumingTypes;
417  for (Value v : yieldValues)
418  assumingTypes.push_back(v.getType());
419  result.addTypes(assumingTypes);
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // AddOp
424 //===----------------------------------------------------------------------===//
425 
426 LogicalResult mlir::shape::AddOp::inferReturnTypes(
427  MLIRContext *context, Optional<Location> location, ValueRange operands,
428  DictionaryAttr attributes, RegionRange regions,
429  SmallVectorImpl<Type> &inferredReturnTypes) {
430  if (operands[0].getType().isa<SizeType>() ||
431  operands[1].getType().isa<SizeType>())
432  inferredReturnTypes.assign({SizeType::get(context)});
433  else
434  inferredReturnTypes.assign({IndexType::get(context)});
435  return success();
436 }
437 
438 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
439  // SizeType is compatible with IndexType.
440  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
441 }
442 
443 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
444  // add(x, 0) -> x
445  if (matchPattern(getRhs(), m_Zero()))
446  return getLhs();
447 
448  return constFoldBinaryOp<IntegerAttr>(
449  operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // AssumingAllOp
454 //===----------------------------------------------------------------------===//
455 
456 namespace {
// Rewrites an assuming_all whose inputs are all CstrEqOp witnesses into a
// single CstrEqOp over the concatenated shapes. Fails if any input is not a
// CstrEqOp. Also fails if a CstrEqOp (with shapes) shares no shape with the
// shapes collected so far, since the equalities could then not be chained.
// NOTE(review): the `using OpRewritePattern<...>::OpRewritePattern;` lines of
// both patterns appear to be missing from this rendered listing.
457 struct AssumingAllToCstrEqCanonicalization
458  : public OpRewritePattern<AssumingAllOp> {
460 
461  LogicalResult matchAndRewrite(AssumingAllOp op,
462  PatternRewriter &rewriter) const override {
463  SmallVector<Value, 8> shapes;
464  for (Value w : op.getInputs()) {
465  auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
466  if (!cstrEqOp)
467  return failure();
468  bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
469  return llvm::is_contained(shapes, s);
470  });
471  if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
472  return failure();
473  shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
474  }
475  rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
476  return success();
477  }
478 };
479 
// Rebuilds an op with duplicate operands removed. The first occurrence of
// each operand is kept in its original order, and result types and
// attributes are preserved.
480 template <typename OpTy>
481 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
483 
484  LogicalResult matchAndRewrite(OpTy op,
485  PatternRewriter &rewriter) const override {
486  // Find unique operands.
487  SmallVector<Value, 2> unique;
488  for (Value v : op.getOperands()) {
489  if (!llvm::is_contained(unique, v))
490  unique.push_back(v);
491  }
492 
493  // Reduce op to equivalent with unique operands.
494  if (unique.size() < op.getNumOperands()) {
495  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
496  op->getAttrs());
497  return success();
498  }
499 
500  return failure();
501  }
502 };
503 } // namespace
504 
505 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
506  MLIRContext *context) {
507  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
508  RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
509 }
510 
511 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
512  // Iterate in reverse to first handle all constant operands. They are
513  // guaranteed to be the tail of the inputs because this is commutative.
514  for (int idx = operands.size() - 1; idx >= 0; idx--) {
515  Attribute a = operands[idx];
516  // Cannot fold if any inputs are not constant;
517  if (!a)
518  return nullptr;
519 
520  // We do not need to keep statically known values after handling them in
521  // this method.
522  getOperation()->eraseOperand(idx);
523 
524  // Always false if any input is statically known false
525  if (!a.cast<BoolAttr>().getValue())
526  return a;
527  }
528  // If this is reached, all inputs were statically known passing.
529  return BoolAttr::get(getContext(), true);
530 }
531 
532 static LogicalResult verify(AssumingAllOp op) {
533  // Ensure that AssumingAllOp contains at least one operand
534  if (op.getNumOperands() == 0)
535  return op.emitOpError("no operands specified");
536 
537  return success();
538 }
539 
540 void AssumingAllOp::build(OpBuilder &b, OperationState &state,
541  ValueRange inputs) {
542  build(b, state, b.getType<WitnessType>(), inputs);
543 }
544 
545 //===----------------------------------------------------------------------===//
546 // BroadcastOp
547 //===----------------------------------------------------------------------===//
548 
549 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
550  if (getShapes().size() == 1) {
551  // Otherwise, we need a cast which would be a canonicalization, not folding.
552  if (getShapes().front().getType() != getType())
553  return nullptr;
554  return getShapes().front();
555  }
556 
557  // TODO: Support folding with more than 2 input shapes
558  if (getShapes().size() > 2)
559  return nullptr;
560 
561  if (!operands[0] || !operands[1])
562  return nullptr;
563  auto lhsShape = llvm::to_vector<6>(
564  operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
565  auto rhsShape = llvm::to_vector<6>(
566  operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
567  SmallVector<int64_t, 6> resultShape;
568 
569  // If the shapes are not compatible, we can't fold it.
570  // TODO: Fold to an "error".
571  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
572  return nullptr;
573 
574  Builder builder(getContext());
575  return builder.getIndexTensorAttr(resultShape);
576 }
577 
578 static LogicalResult verify(BroadcastOp op) {
579  return verifyShapeOrExtentTensorOp(op);
580 }
581 
582 namespace {
// Drops operands that are statically known to be empty shapes: extent
// tensors of static size 0, or ConstShapeOps with no extents. The op is
// rebuilt with the same result types and attributes.
// NOTE(review): if every operand is empty, the rebuilt op has no operands.
// Confirm that each OpTy tolerates that: CstrBroadcastableOp's verifier
// requires at least 2 shapes.
// NOTE(review): the `using OpRewritePattern<...>::OpRewritePattern;` lines of
// the patterns in this namespace appear to be missing from this listing.
583 template <typename OpTy>
584 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
586 
587  LogicalResult matchAndRewrite(OpTy op,
588  PatternRewriter &rewriter) const override {
589  auto isPotentiallyNonEmptyShape = [](Value shape) {
590  if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
591  if (extentTensorTy.getDimSize(0) == 0)
592  return false;
593  }
594  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
595  if (constShape.getShape().empty())
596  return false;
597  }
598  return true;
599  };
600  auto newOperands = llvm::to_vector<8>(
601  llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
602 
603  // Reduce op to equivalent without empty shape operands.
604  if (newOperands.size() < op.getNumOperands()) {
605  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
606  op->getAttrs());
607  return success();
608  }
609 
610  return failure();
611  }
612 };
613 
// Replaces a single-operand broadcast with its operand. If the types differ,
// a FromExtentTensorOp (for a !shape.shape result) or a tensor.cast (between
// extent tensor types) is inserted.
614 struct BroadcastForwardSingleOperandPattern
615  : public OpRewritePattern<BroadcastOp> {
617 
618  LogicalResult matchAndRewrite(BroadcastOp op,
619  PatternRewriter &rewriter) const override {
620  if (op.getNumOperands() != 1)
621  return failure();
622  Value replacement = op.getShapes().front();
623 
624  // Insert cast if needed.
625  if (replacement.getType() != op.getType()) {
626  auto loc = op.getLoc();
627  if (op.getType().isa<ShapeType>()) {
628  replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
629  } else {
630  assert(!op.getType().isa<ShapeType>() &&
631  !replacement.getType().isa<ShapeType>() &&
632  "expect extent tensor cast");
633  replacement =
634  rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
635  }
636  }
637 
638  rewriter.replaceOp(op, replacement);
639  return success();
640  }
641 };
642 
// Combines the ConstShapeOp operands that are broadcast-compatible with each
// other into one new ConstShapeOp operand. Constants that are not compatible
// with the running result, and all non-constant operands, are kept as-is.
// The pattern applies only when at least two constants were combined.
// NOTE(review): the line opening the `if (OpTrait::util::getBroadcastedShape(`
// call below is missing from this rendered listing.
643 struct BroadcastFoldConstantOperandsPattern
644  : public OpRewritePattern<BroadcastOp> {
646 
647  LogicalResult matchAndRewrite(BroadcastOp op,
648  PatternRewriter &rewriter) const override {
649  SmallVector<int64_t, 8> foldedConstantShape;
650  SmallVector<Value, 8> newShapeOperands;
651  for (Value shape : op.getShapes()) {
652  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
653  SmallVector<int64_t, 8> newFoldedConstantShape;
655  foldedConstantShape,
656  llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
657  newFoldedConstantShape)) {
658  foldedConstantShape = newFoldedConstantShape;
659  continue;
660  }
661  }
662  newShapeOperands.push_back(shape);
663  }
664 
665  // Need at least two constant operands to fold anything.
666  if (op.getNumOperands() - newShapeOperands.size() < 2)
667  return failure();
668 
669  auto foldedConstantOperandsTy = RankedTensorType::get(
670  {static_cast<int64_t>(foldedConstantShape.size())},
671  rewriter.getIndexType());
672  newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
673  op.getLoc(), foldedConstantOperandsTy,
674  rewriter.getIndexTensorAttr(foldedConstantShape)));
675  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
676  newShapeOperands);
677  return success();
678  }
679 };
680 
// Bypasses tensor.cast operands whose result has a dynamic leading dimension.
// Such a cast discards the extent count, so the cast's source is used
// directly instead.
// NOTE(review): unlike the other patterns here, the rebuilt op does not get
// op->getAttrs(), so any attributes on the original op are dropped. Confirm
// this is intended.
681 template <typename OpTy>
682 struct CanonicalizeCastExtentTensorOperandsPattern
683  : public OpRewritePattern<OpTy> {
685 
686  LogicalResult matchAndRewrite(OpTy op,
687  PatternRewriter &rewriter) const override {
688  // Canonicalize operands.
689  bool anyChange = false;
690  auto canonicalizeOperand = [&](Value operand) {
691  if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
692  // Only eliminate the cast if it holds no shape information.
693  bool isInformationLoosingCast =
694  castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
695  if (isInformationLoosingCast) {
696  anyChange = true;
697  return castOp.source();
698  }
699  }
700  return operand;
701  };
702  auto newOperands = llvm::to_vector<8>(
703  llvm::map_range(op.getOperands(), canonicalizeOperand));
704 
705  // Rewrite op if any change required.
706  if (!anyChange)
707  return failure();
708  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
709  return success();
710  }
711 };
712 
// Gives a broadcast with a dynamically sized extent tensor result a static
// result type. The static size is the largest operand extent count, which
// requires every extent tensor operand to be statically sized. A tensor.cast
// back to the original result type keeps users unchanged.
// NOTE(review): the new BroadcastOp is created without the original op's
// attributes.
713 struct BroadcastConcretizeResultTypePattern
714  : public OpRewritePattern<BroadcastOp> {
716 
717  LogicalResult matchAndRewrite(BroadcastOp op,
718  PatternRewriter &rewriter) const override {
719  // Only concretize dynamic extent tensor result types.
720  auto resultTy = op.getType().dyn_cast<RankedTensorType>();
721  if (!resultTy || !resultTy.isDynamicDim(0))
722  return failure();
723 
724  // Infer resulting shape rank if possible.
725  int64_t maxRank = 0;
726  for (Value shape : op.getShapes()) {
727  if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
728  // Cannot infer resulting shape rank if any operand is dynamically
729  // ranked.
730  if (extentTensorTy.isDynamicDim(0))
731  return failure();
732  maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
733  }
734  }
735 
736  auto newOp = rewriter.create<BroadcastOp>(
737  op.getLoc(), getExtentTensorType(getContext(), maxRank),
738  op.getShapes());
739  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
740  return success();
741  }
742 };
743 } // namespace
744 
745 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
746  MLIRContext *context) {
747  patterns.add<BroadcastConcretizeResultTypePattern,
748  BroadcastFoldConstantOperandsPattern,
749  BroadcastForwardSingleOperandPattern,
750  CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
751  RemoveDuplicateOperandsPattern<BroadcastOp>,
752  RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
753 }
754 
755 //===----------------------------------------------------------------------===//
756 // ConcatOp
757 //===----------------------------------------------------------------------===//
758 
759 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
760  if (!operands[0] || !operands[1])
761  return nullptr;
762  auto lhsShape = llvm::to_vector<6>(
763  operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
764  auto rhsShape = llvm::to_vector<6>(
765  operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
766  SmallVector<int64_t, 6> resultShape;
767  resultShape.append(lhsShape.begin(), lhsShape.end());
768  resultShape.append(rhsShape.begin(), rhsShape.end());
769  Builder builder(getContext());
770  return builder.getIndexTensorAttr(resultShape);
771 }
772 
773 //===----------------------------------------------------------------------===//
774 // ConstShapeOp
775 //===----------------------------------------------------------------------===//
776 
777 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
778  p << " ";
779  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
780  p << "[";
781  interleaveComma(op.getShape().getValues<int64_t>(), p,
782  [&](int64_t i) { p << i; });
783  p << "] : ";
784  p.printType(op.getType());
785 }
786 
// Parses the custom form of ConstShapeOp:
//   [attr-dict] `[` integer extents `]` `:` result-type
// The extents are stored as an index tensor attribute named `shape`.
// NOTE(review): the first signature line and the declaration of the extent
// buffer `ints` are missing from this rendered listing.
// NOTE(review): the two `return failure()` paths for a non-array or a
// non-integer element emit no diagnostic of their own.
788  OperationState &result) {
789  if (parser.parseOptionalAttrDict(result.attributes))
790  return failure();
791  // We piggy-back on ArrayAttr parsing, though we don't internally store the
792  // shape as an ArrayAttr.
793  // TODO: Implement custom parser and maybe make syntax a bit more concise.
794  Attribute extentsRaw;
795  NamedAttrList dummy;
796  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
797  return failure();
798  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
799  if (!extentsArray)
800  return failure();
802  for (Attribute extent : extentsArray) {
803  IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
804  if (!attr)
805  return failure();
806  ints.push_back(attr.getInt());
807  }
808  Builder &builder = parser.getBuilder();
809  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
810  Type resultTy;
811  if (parser.parseColonType(resultTy))
812  return failure();
813  result.types.push_back(resultTy);
814  return success();
815 }
816 
817 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
818 
819 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
820  MLIRContext *context) {
821  patterns.add<TensorCastConstShape>(context);
822 }
823 
824 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
825  MLIRContext *context, Optional<Location> location, ValueRange operands,
826  DictionaryAttr attributes, RegionRange regions,
827  SmallVectorImpl<Type> &inferredReturnTypes) {
828  Builder b(context);
829  auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
830  if (!shape)
831  return emitOptionalError(location, "missing shape attribute");
832  inferredReturnTypes.assign({RankedTensorType::get(
833  {static_cast<int64_t>(shape.size())}, b.getIndexType())});
834  return success();
835 }
836 
837 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
838  TypeRange r) {
839  if (l.size() != 1 || r.size() != 1)
840  return false;
841 
842  Type lhs = l.front();
843  Type rhs = r.front();
844 
845  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
846  // Shape type is compatible with all other valid return types.
847  return true;
848  return lhs == rhs;
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // CstrBroadcastableOp
853 //===----------------------------------------------------------------------===//
854 
855 void CstrBroadcastableOp::getCanonicalizationPatterns(
856  RewritePatternSet &patterns, MLIRContext *context) {
857  // Canonicalization patterns have overlap with the considerations during
858  // folding in case additional shape information is inferred at some point that
859  // does not result in folding.
860  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
861  CstrBroadcastableEqOps,
862  RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
863  RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
864 }
865 
866 // Return true if at most one attribute is not known to represent a scalar
867 // broadcast (an empty extent list).
// A null attribute (non-constant operand) and a non-empty extent list both
// count as non-scalar. Zero such attributes also yields true.
// NOTE(review): the function header line is missing from this rendered
// listing.
869  bool nonScalarSeen = false;
870  for (Attribute a : attributes) {
871  if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
872  if (nonScalarSeen)
873  return false;
874  nonScalarSeen = true;
875  }
876  }
877  return true;
878 }
879 
// Folds the constraint to a passing witness when broadcastability can be
// proven statically. This is tried from three sources in turn: (1) all but
// at most one operand are known scalars; (2) all operands are constant
// extent lists; (3) static extents recovered from the producing ops via
// getShapeVec. Otherwise the op is left unfolded.
// NOTE(review): the declarations of `extents` and the final return statements
// of both lambdas are missing from this rendered listing. Presumably each
// lambda returns whether the collected extents are statically known to be
// broadcastable; confirm against the original source.
880 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
881  // No broadcasting is needed if all operands but one are scalar.
882  if (hasAtMostSingleNonScalar(operands))
883  return BoolAttr::get(getContext(), true);
884 
// All operands constant: check the constant extent lists directly.
885  if ([&] {
887  for (const auto &operand : operands) {
888  if (!operand)
889  return false;
890  extents.push_back(llvm::to_vector<6>(
891  operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
892  }
894  }())
895  return BoolAttr::get(getContext(), true);
896 
897  // Lastly, see if folding can be completed based on what constraints are known
898  // on the input shapes.
899  if ([&] {
901  for (auto shapeValue : getShapes()) {
902  extents.emplace_back();
903  if (failed(getShapeVec(shapeValue, extents.back())))
904  return false;
905  }
907  }())
908  return BoolAttr::get(getContext(), true);
909 
910  // Because a failing witness result here represents an eventual assertion
911  // failure, we do not replace it with a constant witness.
912  return nullptr;
913 }
914 
915 static LogicalResult verify(CstrBroadcastableOp op) {
916  // Ensure that AssumingAllOp contains at least one operand
917  if (op.getNumOperands() < 2)
918  return op.emitOpError("required at least 2 input shapes");
919  return success();
920 }
921 
922 //===----------------------------------------------------------------------===//
923 // CstrEqOp
924 //===----------------------------------------------------------------------===//
925 
926 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
927  MLIRContext *context) {
928  // If inputs are equal, return passing witness
929  patterns.add<CstrEqEqOps>(context);
930 }
931 
932 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
933  if (llvm::all_of(operands,
934  [&](Attribute a) { return a && a == operands[0]; }))
935  return BoolAttr::get(getContext(), true);
936 
937  // Because a failing witness result here represents an eventual assertion
938  // failure, we do not try to replace it with a constant witness. Similarly, we
939  // cannot if there are any non-const inputs.
940  return nullptr;
941 }
942 
943 //===----------------------------------------------------------------------===//
944 // ConstSizeOp
945 //===----------------------------------------------------------------------===//
946 
947 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
948  int64_t value) {
949  build(builder, result, builder.getIndexAttr(value));
950 }
951 
952 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
953 
954 void ConstSizeOp::getAsmResultNames(
955  llvm::function_ref<void(Value, StringRef)> setNameFn) {
956  SmallString<4> buffer;
957  llvm::raw_svector_ostream os(buffer);
958  os << "c" << getValue();
959  setNameFn(getResult(), os.str());
960 }
961 
962 //===----------------------------------------------------------------------===//
963 // ConstWitnessOp
964 //===----------------------------------------------------------------------===//
965 
966 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
967  return getPassingAttr();
968 }
969 
970 //===----------------------------------------------------------------------===//
971 // CstrRequireOp
972 //===----------------------------------------------------------------------===//
973 
974 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
975  return operands[0];
976 }
977 
978 //===----------------------------------------------------------------------===//
979 // DivOp
980 //===----------------------------------------------------------------------===//
981 
982 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
983  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
984  if (!lhs)
985  return nullptr;
986  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
987  if (!rhs)
988  return nullptr;
989 
990  // Division in APInt does not follow floor(lhs, rhs) when the result is
991  // negative. Rather, APInt rounds toward zero.
992  APInt quotient, remainder;
993  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
994  if (quotient.isNegative() && !remainder.isNullValue()) {
995  quotient -= 1;
996  }
997 
998  Type indexTy = IndexType::get(getContext());
999  return IntegerAttr::get(indexTy, quotient);
1000 }
1001 
1002 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1003  MLIRContext *context, Optional<Location> location, ValueRange operands,
1004  DictionaryAttr attributes, RegionRange regions,
1005  SmallVectorImpl<Type> &inferredReturnTypes) {
1006  if (operands[0].getType().isa<SizeType>() ||
1007  operands[1].getType().isa<SizeType>())
1008  inferredReturnTypes.assign({SizeType::get(context)});
1009  else
1010  inferredReturnTypes.assign({IndexType::get(context)});
1011  return success();
1012 }
1013 
1014 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1015  // SizeType is compatible with IndexType.
1016  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1017 }
1018 
1019 //===----------------------------------------------------------------------===//
1020 // ShapeEqOp
1021 //===----------------------------------------------------------------------===//
1022 
1023 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
1024  bool allSame = true;
1025  if (!operands.empty() && !operands[0])
1026  return {};
1027  for (Attribute operand : operands.drop_front(1)) {
1028  if (!operand)
1029  return {};
1030  allSame = allSame && operand == operands[0];
1031  }
1032  return BoolAttr::get(getContext(), allSame);
1033 }
1034 
1035 //===----------------------------------------------------------------------===//
1036 // IndexToSizeOp
1037 //===----------------------------------------------------------------------===//
1038 
1039 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
1040  // Constant values of both types, `shape.size` and `index`, are represented as
1041  // `IntegerAttr`s which makes constant folding simple.
1042  if (Attribute arg = operands[0])
1043  return arg;
1044  return {};
1045 }
1046 
1047 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1048  MLIRContext *context) {
1049  patterns.add<SizeToIndexToSizeCanonicalization>(context);
1050 }
1051 
1052 //===----------------------------------------------------------------------===//
1053 // FromExtentsOp
1054 //===----------------------------------------------------------------------===//
1055 
1056 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
1057  if (llvm::any_of(operands, [](Attribute a) { return !a; }))
1058  return nullptr;
1059  SmallVector<int64_t, 6> extents;
1060  for (auto attr : operands)
1061  extents.push_back(attr.cast<IntegerAttr>().getInt());
1062  Builder builder(getContext());
1063  return builder.getIndexTensorAttr(extents);
1064 }
1065 
1066 //===----------------------------------------------------------------------===//
1067 // FunctionLibraryOp
1068 //===----------------------------------------------------------------------===//
1069 
/// Builds a FunctionLibraryOp by appending a named attribute to `result`.
/// NOTE(review): presumably this is the symbol-name attribute built from
/// `name` — confirm against the full source.
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
}
1075 
1076 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1077  auto attr = getMapping()
1078  .get(op->getName().getIdentifier())
1079  .dyn_cast_or_null<FlatSymbolRefAttr>();
1080  if (!attr)
1081  return nullptr;
1082  return lookupSymbol<FuncOp>(attr);
1083 }
1084 
                                          OperationState &result) {
  // Parses, in order: the symbol name, an optional `attributes {...}`
  // dictionary, the body region, the `mapping` keyword, and the mapping
  // dictionary attribute (stored under the "mapping" name).

  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  // The mapping is parsed as an untyped (NoneType) dictionary attribute.
  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
1110 
/// Prints a FunctionLibraryOp in this order: the symbol name, the remaining
/// attributes, the body region, and the `mapping` dictionary. The symbol-name
/// and `mapping` attributes are elided from the generic attribute list because
/// they are printed explicitly. The region is printed without entry block
/// arguments or block terminators.
/// NOTE(review): unlike the ReduceOp printer in this file, this function is
/// not declared `static`.
void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << ' ';
  p.printSymbolName(op.getName());
      op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.getMappingAttr());
}
1122 
1123 //===----------------------------------------------------------------------===//
1124 // GetExtentOp
1125 //===----------------------------------------------------------------------===//
1126 
1127 Optional<int64_t> GetExtentOp::getConstantDim() {
1128  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1129  return constSizeOp.getValue().getLimitedValue();
1130  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1131  return constantOp.getValue().cast<IntegerAttr>().getInt();
1132  return llvm::None;
1133 }
1134 
1135 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1136  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1137  if (!elements)
1138  return nullptr;
1139  Optional<int64_t> dim = getConstantDim();
1140  if (!dim.hasValue())
1141  return nullptr;
1142  if (dim.getValue() >= elements.getNumElements())
1143  return nullptr;
1144  return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
1145 }
1146 
1147 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1148  int64_t dim) {
1149  auto loc = result.location;
1150  auto dimAttr = builder.getIndexAttr(dim);
1151  if (shape.getType().isa<ShapeType>()) {
1152  Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1153  build(builder, result, builder.getType<SizeType>(), shape, dim);
1154  } else {
1155  Value dim =
1156  builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1157  build(builder, result, builder.getIndexType(), shape, dim);
1158  }
1159 }
1160 
1161 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1162  MLIRContext *context, Optional<Location> location, ValueRange operands,
1163  DictionaryAttr attributes, RegionRange regions,
1164  SmallVectorImpl<Type> &inferredReturnTypes) {
1165  inferredReturnTypes.assign({IndexType::get(context)});
1166  return success();
1167 }
1168 
1169 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1170  TypeRange r) {
1171  // SizeType is compatible with IndexType.
1172  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1173 }
1174 
1175 //===----------------------------------------------------------------------===//
1176 // IsBroadcastableOp
1177 //===----------------------------------------------------------------------===//
1178 
1179 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1180  MLIRContext *context) {
1181  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1182 }
1183 
1184 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1185  // Can always broadcast fewer than two shapes.
1186  if (operands.size() < 2) {
1187  return BoolAttr::get(getContext(), true);
1188  }
1189 
1190  return nullptr;
1191 }
1192 
1193 //===----------------------------------------------------------------------===//
1194 // MeetOp
1195 //===----------------------------------------------------------------------===//
1196 
1197 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1198  MLIRContext *context, Optional<Location> location, ValueRange operands,
1199  DictionaryAttr attributes, RegionRange regions,
1200  SmallVectorImpl<Type> &inferredReturnTypes) {
1201  inferredReturnTypes.assign({operands[0].getType()});
1202  return success();
1203 }
1204 
1205 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1206  if (l.size() != 1 || r.size() != 1)
1207  return false;
1208  if (l == r)
1209  return true;
1210 
1211  Type lhs = l.front();
1212  Type rhs = r.front();
1213 
1214  if (lhs != rhs)
1215  return false;
1216 
1217  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
1218  return true;
1219 
1220  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1221  return true;
1222  return false;
1223 }
1224 
1225 //===----------------------------------------------------------------------===//
1226 // RankOp
1227 //===----------------------------------------------------------------------===//
1228 
1229 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1230  auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1231  if (!shape)
1232  return {};
1233  int64_t rank = shape.getNumElements();
1234  Builder builder(getContext());
1235  return builder.getIndexAttr(rank);
1236 }
1237 
/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3
///
/// If the `rank` result has `index` type, an `arith.constant` index op is
/// created instead of `shape.const_size`. Other result types are not rewritten.

namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    // The shape operand must be produced directly by `shape.shape_of`...
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    // ...applied to a ranked tensor, whose rank is statically known.
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the rank as a constant matching the op's result type.
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace
1279 
1280 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1281  MLIRContext *context) {
1282  patterns.add<RankShapeOfCanonicalizationPattern>(context);
1283 }
1284 
1285 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1286  MLIRContext *context, Optional<Location> location, ValueRange operands,
1287  DictionaryAttr attributes, RegionRange regions,
1288  SmallVectorImpl<Type> &inferredReturnTypes) {
1289  if (operands[0].getType().isa<ShapeType>())
1290  inferredReturnTypes.assign({SizeType::get(context)});
1291  else
1292  inferredReturnTypes.assign({IndexType::get(context)});
1293  return success();
1294 }
1295 
1296 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1297  // SizeType is compatible with IndexType.
1298  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1299 }
1300 
1301 //===----------------------------------------------------------------------===//
1302 // NumElementsOp
1303 //===----------------------------------------------------------------------===//
1304 
1305 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
1306 
1307  // Fold only when argument constant.
1308  Attribute shape = operands[0];
1309  if (!shape)
1310  return {};
1311 
1312  APInt product(64, 1);
1313  for (auto value : shape.cast<DenseIntElementsAttr>())
1314  product *= value;
1315  Builder builder(getContext());
1316  return builder.getIndexAttr(product.getLimitedValue());
1317 }
1318 
1319 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1320  MLIRContext *context, Optional<Location> location, ValueRange operands,
1321  DictionaryAttr attributes, RegionRange regions,
1322  SmallVectorImpl<Type> &inferredReturnTypes) {
1323  if (operands[0].getType().isa<ShapeType>())
1324  inferredReturnTypes.assign({SizeType::get(context)});
1325  else
1326  inferredReturnTypes.assign({IndexType::get(context)});
1327  return success();
1328 }
1329 
1330 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1331  TypeRange r) {
1332  // SizeType is compatible with IndexType.
1333  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1334 }
1335 
1336 //===----------------------------------------------------------------------===//
1337 // MaxOp
1338 //===----------------------------------------------------------------------===//
1339 
1340 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1341  // If operands are equal, just propagate one.
1342  if (getLhs() == getRhs())
1343  return getLhs();
1344  return nullptr;
1345 }
1346 
1347 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1348  MLIRContext *context, Optional<Location> location, ValueRange operands,
1349  DictionaryAttr attributes, RegionRange regions,
1350  SmallVectorImpl<Type> &inferredReturnTypes) {
1351  if (operands[0].getType() == operands[1].getType())
1352  inferredReturnTypes.assign({operands[0].getType()});
1353  else
1354  inferredReturnTypes.assign({SizeType::get(context)});
1355  return success();
1356 }
1357 
1358 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1359  if (l.size() != 1 || r.size() != 1)
1360  return false;
1361  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1362  return true;
1363  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1364  return true;
1365  return false;
1366 }
1367 
1368 //===----------------------------------------------------------------------===//
1369 // MinOp
1370 //===----------------------------------------------------------------------===//
1371 
1372 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1373  // If operands are equal, just propagate one.
1374  if (getLhs() == getRhs())
1375  return getLhs();
1376  return nullptr;
1377 }
1378 
1379 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1380  MLIRContext *context, Optional<Location> location, ValueRange operands,
1381  DictionaryAttr attributes, RegionRange regions,
1382  SmallVectorImpl<Type> &inferredReturnTypes) {
1383  if (operands[0].getType() == operands[1].getType())
1384  inferredReturnTypes.assign({operands[0].getType()});
1385  else
1386  inferredReturnTypes.assign({SizeType::get(context)});
1387  return success();
1388 }
1389 
1390 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1391  if (l.size() != 1 || r.size() != 1)
1392  return false;
1393  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1394  return true;
1395  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1396  return true;
1397  return false;
1398 }
1399 
1400 //===----------------------------------------------------------------------===//
1401 // MulOp
1402 //===----------------------------------------------------------------------===//
1403 
1404 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1405  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1406  if (!lhs)
1407  return nullptr;
1408  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1409  if (!rhs)
1410  return nullptr;
1411  APInt folded = lhs.getValue() * rhs.getValue();
1412  Type indexTy = IndexType::get(getContext());
1413  return IntegerAttr::get(indexTy, folded);
1414 }
1415 
1416 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1417  MLIRContext *context, Optional<Location> location, ValueRange operands,
1418  DictionaryAttr attributes, RegionRange regions,
1419  SmallVectorImpl<Type> &inferredReturnTypes) {
1420  if (operands[0].getType().isa<SizeType>() ||
1421  operands[1].getType().isa<SizeType>())
1422  inferredReturnTypes.assign({SizeType::get(context)});
1423  else
1424  inferredReturnTypes.assign({IndexType::get(context)});
1425  return success();
1426 }
1427 
1428 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1429  // SizeType is compatible with IndexType.
1430  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1431 }
1432 //===----------------------------------------------------------------------===//
1433 // ShapeOfOp
1434 //===----------------------------------------------------------------------===//
1435 
1436 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1437  auto type = getOperand().getType().dyn_cast<ShapedType>();
1438  if (!type || !type.hasStaticShape())
1439  return nullptr;
1440  Builder builder(getContext());
1441  return builder.getIndexTensorAttr(type.getShape());
1442 }
1443 
namespace {
/// Rewrites a `shape.shape_of` whose argument has a shaped type but whose
/// result does not (e.g. a `!shape.shape` result) into a new `shape.shape_of`
/// built from the argument alone.
/// NOTE(review): the replacement presumably relies on the op's type-inferring
/// builder to produce an extent tensor result — confirm.
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies to arguments of shaped type.
    if (!op.getArg().getType().isa<ShapedType>())
      return failure();
    // Nothing to do if the result already has a shaped type.
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
                                                  op.getArg());
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
// i.e. fold a `tensor.cast` of a `shape.shape_of` result into a `shape.shape_of`
// that directly produces the cast's rank-1 result type.
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    // The cast must produce a rank-1 ranked tensor (an extent tensor).
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    // The cast source must be produced directly by `shape.shape_of`.
    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict: a static cast
    // length has to equal the argument's rank.
    auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
} // namespace
1493 
1494 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1495  MLIRContext *context) {
1496  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
1497  ExtractFromShapeOfExtentTensor>(context);
1498 }
1499 
1500 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1501  MLIRContext *context, Optional<Location> location, ValueRange operands,
1502  DictionaryAttr attributes, RegionRange regions,
1503  SmallVectorImpl<Type> &inferredReturnTypes) {
1504  if (operands[0].getType().isa<ValueShapeType>())
1505  inferredReturnTypes.assign({ShapeType::get(context)});
1506  else {
1507  auto shapedTy = operands[0].getType().cast<ShapedType>();
1508  int64_t rank =
1509  shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1510  Type indexTy = IndexType::get(context);
1511  Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1512  inferredReturnTypes.assign({extentTensorTy});
1513  }
1514  return success();
1515 }
1516 
1517 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1518  if (l.size() != 1 || r.size() != 1)
1519  return false;
1520  if (l == r)
1521  return true;
1522 
1523  Type lhs = l.front();
1524  Type rhs = r.front();
1525 
1526  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
1527  return false;
1528 
1529  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
1530  // Shape type is compatible with all other valid return types.
1531  return true;
1532 
1533  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1534  return true;
1535  return false;
1536 }
1537 
1538 //===----------------------------------------------------------------------===//
1539 // SizeToIndexOp
1540 //===----------------------------------------------------------------------===//
1541 
1542 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1543  // Constant values of both types, `shape.size` and `index`, are represented as
1544  // `IntegerAttr`s which makes constant folding simple.
1545  if (Attribute arg = operands[0])
1546  return arg;
1547  return impl::foldCastOp(*this);
1548 }
1549 
1550 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1551  MLIRContext *context) {
1552  patterns.add<IndexToSizeToIndexCanonicalization>(context);
1553 }
1554 
1555 //===----------------------------------------------------------------------===//
1556 // YieldOp
1557 //===----------------------------------------------------------------------===//
1558 
1559 static LogicalResult verify(shape::YieldOp op) {
1560  auto *parentOp = op->getParentOp();
1561  auto results = parentOp->getResults();
1562  auto operands = op.getOperands();
1563 
1564  if (parentOp->getNumResults() != op.getNumOperands())
1565  return op.emitOpError() << "number of operands does not match number of "
1566  "results of its parent";
1567  for (auto e : llvm::zip(results, operands))
1568  if (std::get<0>(e).getType() != std::get<1>(e).getType())
1569  return op.emitOpError()
1570  << "types mismatch between yield op and its parent";
1571 
1572  return success();
1573 }
1574 
1575 //===----------------------------------------------------------------------===//
1576 // SplitAtOp
1577 //===----------------------------------------------------------------------===//
1578 
1579 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1580  SmallVectorImpl<OpFoldResult> &results) {
1581  if (!operands[0] || !operands[1])
1582  return failure();
1583  auto shapeVec = llvm::to_vector<6>(
1584  operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1585  auto shape = llvm::makeArrayRef(shapeVec);
1586  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1587  // Verify that the split point is in the correct range.
1588  // TODO: Constant fold to an "error".
1589  int64_t rank = shape.size();
1590  if (!(-rank <= splitPoint && splitPoint <= rank))
1591  return failure();
1592  if (splitPoint < 0)
1593  splitPoint += shape.size();
1594  Builder builder(operands[0].getContext());
1595  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1596  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1597  return success();
1598 }
1599 
1600 //===----------------------------------------------------------------------===//
1601 // ToExtentTensorOp
1602 //===----------------------------------------------------------------------===//
1603 
1604 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1605  if (!operands[0])
1606  return impl::foldCastOp(*this);
1607  Builder builder(getContext());
1608  auto shape = llvm::to_vector<6>(
1609  operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1610  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1611  builder.getIndexType());
1612  return DenseIntElementsAttr::get(type, shape);
1613 }
1614 
1615 //===----------------------------------------------------------------------===//
1616 // ReduceOp
1617 //===----------------------------------------------------------------------===//
1618 
1619 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1620  ValueRange initVals) {
1621  result.addOperands(shape);
1622  result.addOperands(initVals);
1623 
1624  Region *bodyRegion = result.addRegion();
1625  bodyRegion->push_back(new Block);
1626  Block &bodyBlock = bodyRegion->front();
1627  bodyBlock.addArgument(builder.getIndexType(), result.location);
1628 
1629  Type elementType;
1630  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1631  elementType = tensorType.getElementType();
1632  else
1633  elementType = SizeType::get(builder.getContext());
1634  bodyBlock.addArgument(elementType, shape.getLoc());
1635 
1636  for (Value initVal : initVals) {
1637  bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
1638  result.addTypes(initVal.getType());
1639  }
1640 }
1641 
1642 static LogicalResult verify(ReduceOp op) {
1643  // Verify block arg types.
1644  Block &block = op.getRegion().front();
1645 
1646  // The block takes index, extent, and aggregated values as arguments.
1647  auto blockArgsCount = op.getInitVals().size() + 2;
1648  if (block.getNumArguments() != blockArgsCount)
1649  return op.emitOpError() << "ReduceOp body is expected to have "
1650  << blockArgsCount << " arguments";
1651 
1652  // The first block argument is the index and must always be of type `index`.
1653  if (!block.getArgument(0).getType().isa<IndexType>())
1654  return op.emitOpError(
1655  "argument 0 of ReduceOp body is expected to be of IndexType");
1656 
1657  // The second block argument is the extent and must be of type `size` or
1658  // `index`, depending on whether the reduce operation is applied to a shape or
1659  // to an extent tensor.
1660  Type extentTy = block.getArgument(1).getType();
1661  if (op.getShape().getType().isa<ShapeType>()) {
1662  if (!extentTy.isa<SizeType>())
1663  return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
1664  "SizeType if the ReduceOp operates on a ShapeType");
1665  } else {
1666  if (!extentTy.isa<IndexType>())
1667  return op.emitOpError(
1668  "argument 1 of ReduceOp body is expected to be of IndexType if the "
1669  "ReduceOp operates on an extent tensor");
1670  }
1671 
1672  for (const auto &type : llvm::enumerate(op.getInitVals()))
1673  if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1674  return op.emitOpError()
1675  << "type mismatch between argument " << type.index() + 2
1676  << " of ReduceOp body and initial value " << type.index();
1677  return success();
1678 }
1679 
  // Parse operands: the operand list, a colon followed by the shape/extent
  // tensor type, and an optional `->` result type list.
  // NOTE(review): the operand-list delimiter argument is not visible here;
  // presumably parenthesized, matching the ReduceOp printer — confirm.
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands: the first operand is the shape (typed with the colon
  // type); the remaining ones are initial values typed by the result types.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body; its block arguments are declared inside the region.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1709 
1710 static void print(OpAsmPrinter &p, ReduceOp op) {
1711  p << '(' << op.getShape() << ", " << op.getInitVals()
1712  << ") : " << op.getShape().getType();
1713  p.printOptionalArrowTypeList(op.getResultTypes());
1714  p << ' ';
1715  p.printRegion(op.getRegion());
1716  p.printOptionalAttrDict(op->getAttrs());
1717 }
1718 
1719 #define GET_OP_CLASSES
1720 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, BlockAndValueMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
virtual ParseResult parseOperand(OperandType &result)=0
Parse a single operand.
This is the representation of an operand reference.
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
Definition: Shape.cpp:85
ParseResult resolveOperands(ArrayRef< OperandType > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:55
MLIRContext * getContext() const
Definition: Builders.h:54
U cast() const
Definition: Attributes.h:123
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:373
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:151
Ty getType(Args... args)
Get or construct an instance of the type 'ty' with provided arguments.
Definition: Builders.h:83
Operation & back()
Definition: Block.h:143
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:141
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:247
virtual void printType(Type type)
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of &#39;symbolTableOp&#39;.
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result)
Definition: Shape.cpp:1680
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:329
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:244
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block...
void push_back(Block *block)
Definition: Region.h:61
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
operand_type_range getOperandTypes()
Definition: Operation.h:266
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
The Witness represents a runtime constraint, to be used as shape related preconditions on code execut...
Definition: Shape.h:64
Operation & front()
Definition: Block.h:144
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:81
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
virtual llvm::SMLoc getNameLoc() const =0
Return the location of the original name token.
static constexpr const bool value
ParseResult parseSymbolName(StringAttr &result, StringRef attrName, NamedAttrList &attrs)
Parse an -identifier and store it (without the &#39;@&#39; symbol) in a string attribute named &#39;attrName&#39;...
static ConcreteT get(MLIRContext *ctx, Args... args)
Get or create a new ConcreteT instance within the ctx.
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
virtual ParseResult parseOperandList(SmallVectorImpl< OperandType > &result, int requiredOperandCount=-1, Delimiter delimiter=Delimiter::None)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:137
LogicalResult emitOptionalError(Optional< Location > loc, Args &&... args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:464
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
bool getValue() const
Return the boolean value of this attribute.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
virtual ParseResult resolveOperand(const OperandType &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
Definition: Shape.cpp:98
The type of a single dimension.
Definition: Shape.h:50
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual ParseResult parseRegion(Region &region, ArrayRef< OperandType > arguments={}, ArrayRef< Type > argTypes={}, ArrayRef< Location > argLocations={}, bool enableNameShadowing=false)=0
Parses a region.
void addOperands(ValueRange newOperands)
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
U dyn_cast() const
Definition: Types.h:244
unsigned getNumArguments()
Definition: Block.h:119
Attributes are known-constant values of operations.
Definition: Attributes.h:24
bool isExtentTensorType(Type)
Definition: Shape.cpp:41
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:470
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:44
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:32
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition: Traits.cpp:24
This is the interface that must be implemented by the dialects of operations to be inlined...
Definition: InliningUtils.h:41
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:338
static bool isErrorPropagationPossible(TypeRange operandTypes)
Definition: Shape.cpp:67
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
void addTypes(ArrayRef< Type > newTypes)
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This represents an operation in an abstracted form, suitable for use with the builder APIs...
Parens surrounding zero or more operands.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:73
auto getType() const
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into &#39;result&#39; if the attributes keyword is present.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
Value foldCastOp(Operation *op)
Definition: Operation.cpp:1259
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
Definition: Shape.cpp:46
static void print(OpAsmPrinter &p, AssumingOp op)
Definition: Shape.cpp:285
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
The ValueShape represents a (potentially unknown) runtime value and shape.
Definition: Shape.h:56
static ParseResult parseConstShapeOp(OpAsmParser &parser, OperationState &result)
Definition: Shape.cpp:787
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
NamedAttrList attributes
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
Type front()
Return first type in the range.
Definition: TypeRange.h:158
virtual InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
This class represents a successor of a region.
Region * addRegion()
Create a region that should be attached to the operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
static LogicalResult verifySizeOrIndexOp(Operation *op)
Definition: Shape.cpp:73
ParseResult parseFunctionLibraryOp(OpAsmParser &parser, OperationState &result)
Definition: Shape.cpp:1085
Type parseType(DialectAsmParser &parser)
Parses an LLVM dialect type.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types &#39;Ts&#39; to the pattern list with the given arguments...
Definition: PatternMatch.h:930
IndexType getIndexType()
Definition: Builders.cpp:48
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:59
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamicSize)
Alias type for extent tensors.
Definition: Shape.cpp:37
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
U dyn_cast() const
Definition: Attributes.h:117
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:266
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:78
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with &#39;attribute...
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:322
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
Definition: Shape.cpp:868
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:367
detail::constant_int_value_matcher< 0 > m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:254
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into &#39;result&#39; if it is present.
static ParseResult parseAssumingOp(OpAsmParser &parser, OperationState &result)
Definition: Shape.cpp:258
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "&#39;dim&#39; op " which is convenient for verifiers...
Definition: Operation.cpp:518
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:376
static BoolAttr get(MLIRContext *context, bool value)
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
bool isa() const
Definition: Types.h:234
The shape descriptor type represents rank and dimension sizes.
Definition: Shape.h:44
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:141
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block &#39;source&#39; into the end of block &#39;dest&#39;.
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:95
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:196
result_type_range getResultTypes()
Definition: Operation.h:297
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:201
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type. ...
Definition: FoldUtils.cpp:50
An attribute that represents a reference to a dense integer vector or tensor object.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
static Value max(ImplicitLocOpBuilder &builder, Value a, Value b)
SmallVector< Type, 4 > types
Types of the results of this operation.