//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::shape;

#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"

namespace {
#include "ShapeCanonicalization.inc"
} // namespace

RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
  return RankedTensorType::get({rank}, IndexType::get(ctx));
}

bool shape::isExtentTensorType(Type type) {
  auto ranked = type.dyn_cast<RankedTensorType>();
  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
}
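
// Illustrative note (not in the original source): for rank 3,
// getExtentTensorType(ctx, 3) returns `tensor<3xindex>`, while
// isExtentTensorType accepts any 1-D ranked tensor of `index`, including
// `tensor<?xindex>`.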

LogicalResult shape::getShapeVec(Value input,
                                 SmallVectorImpl<int64_t> &shapeValues) {
  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
    auto type = inputOp.getArg().getType().cast<ShapedType>();
    if (!type.hasRank())
      return failure();
    llvm::append_range(shapeValues, type.getShape());
    return success();
  }
  DenseIntElementsAttr attr;
  if (matchPattern(input, m_Constant(&attr))) {
    llvm::append_range(shapeValues, attr.getValues<int64_t>());
    return success();
  }
  return failure();
}

static bool isErrorPropagationPossible(TypeRange operandTypes) {
  return llvm::any_of(operandTypes, [](Type ty) {
    return ty.isa<SizeType, ShapeType, ValueShapeType>();
  });
}

static LogicalResult verifySizeOrIndexOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!resultTy.isa<SizeType>())
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `size` to propagate them";
  }
  return success();
}

static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!resultTy.isa<ShapeType>())
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `shape` to propagate them";
  }
  return success();
}

template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
}

template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace

void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}

Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, type, value);
  return nullptr;
}

LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
  if (attribute.getName() == "shape.lib") {
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      DenseSet<StringAttr> key;
      for (auto it : arr) {
        if (!it.isa<SymbolRefAttr>())
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        for (auto mapping : shapeFnLib.getMapping()) {
          if (!key.insert(mapping.getName()).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.getName() << "`";
          }
        }
      }
      return success();
    }

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//

// TODO: Canonicalization should be implemented for shapes that can be
// determined through mixtures of the known dimensions of the inputs.
OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
  // Only the last operand is checked because AnyOp is commutative.
  if (operands.back())
    return operands.back();

  return nullptr;
}
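
// Illustrative example (not in the original source): in
//   %0 = shape.any %a, %b
// if %b folds to a constant shape attribute, the fold above returns it and
// %0 is replaced by that constant.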

//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//

ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void AssumingOp::print(OpAsmPrinter &p) {
  bool yieldsResults = !getResults().empty();

  p << " " << getWitness();
  if (yieldsResults)
    p << " -> (" << getResultTypes() << ")";
  p << ' ';
  p.printRegion(getDoRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict((*this)->getAttrs());
}

namespace {
// Removes AssumingOp with a passing witness and inlines the region.
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.getPassingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};
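
// Illustrative example (not in the original source): given
//   %w = shape.const_witness true
//   %r = shape.assuming %w -> (!shape.shape) { ... }
// the pattern above inlines the body into the parent block and replaces %r
// with the yielded value.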

struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace the yield op in the old assuming op's body and move the entire
    // region to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
} // namespace

void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}

// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // AssumingOp has unconditional control flow into the region and back to the
  // parent, so return the correct RegionSuccessor purely based on the index
  // being None or 0.
  if (index) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&getDoRegion()));
}

void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}

void AssumingOp::build(
    OpBuilder &builder, OperationState &result, Value witness,
    function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {

  result.addOperands(witness);
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();

  // Build body.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&bodyBlock);
  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
  builder.create<AssumingYieldOp>(result.location, yieldValues);

  SmallVector<Type, 2> assumingTypes;
  for (Value v : yieldValues)
    assumingTypes.push_back(v.getType());
  result.addTypes(assumingTypes);
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::AddOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
  // add(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  return constFoldBinaryOp<IntegerAttr>(
      operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
}

LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//

namespace {

// Merge multiple `shape.assuming_all` operations together.
//
// %0 = shape.assuming_all %w0, %w1
// %1 = shape.assuming_all %w2, %0
//
// to:
//
// %0 = shape.assuming_all %w0, %w1, %w2
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> operands;

    for (Value operand : op.getInputs()) {
      if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
        operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
      else
        operands.push_back(operand);
    }

    // We didn't find any other `assuming_all` ops to merge with.
    if (operands.size() == op.getNumOperands())
      return failure();

    // Replace with a new `assuming_all` operation with merged constraints.
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
    return success();
  }
};

// Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
// are subsumed by others.
//
// %0 = shape.cstr_broadcastable %shape0, %shape1
// %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//
// %2 = shape.cstr_broadcastable %shape3, %shape4
// %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//
// %4 = shape.assuming_all %0, %1, %2, %3
//
// to:
//
// %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
// %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
// %2 = shape.assuming_all %0, %1
//
// In this example if shapes [0, 1, 2] are broadcastable, then it means that
// shapes [0, 1] are broadcastable too, and can be removed from the list of
// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all `CstrBroadcastableOp` operands first.
    SetVector<CstrBroadcastableOp> operands;
    for (Value operand : op.getInputs()) {
      // TODO: Apply this optimization if some of the witnesses are not
      // produced by the `cstr_broadcastable`.
      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
      if (!broadcastable)
        return failure();

      operands.insert(broadcastable);
    }

    // Skip trivial `assuming_all` operations.
    if (operands.size() <= 1)
      return failure();

    // Collect shapes checked by `cstr_broadcastable` operands.
    SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
    for (auto cstr : operands) {
      DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
      shapes.emplace_back(cstr, std::move(shapesSet));
    }

    // Sort by the number of shape operands (larger to smaller).
    llvm::sort(shapes, [](auto a, auto b) {
      return a.first.getNumOperands() > b.first.getNumOperands();
    });

    // We start from the `cstr_broadcastable` operations with the largest
    // number of shape operands, and remove redundant `cstr_broadcastable`
    // operations. We do this until we find a set of `cstr_broadcastable`
    // operations with non-overlapping constraints.
    SmallVector<CstrBroadcastableOp> markedForErase;

    for (unsigned i = 0; i < shapes.size(); ++i) {
      auto isSubset = [&](auto pair) {
        return llvm::set_is_subset(pair.second, shapes[i].second);
      };

      // Keep redundant `cstr_broadcastable` operations to be erased.
      auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
      for (auto *it0 = it; it0 < shapes.end(); ++it0)
        markedForErase.push_back(it0->first);
      shapes.erase(it, shapes.end());
    }

    // We didn't find any operands that could be removed.
    if (markedForErase.empty())
      return failure();

    // Collect non-overlapping `cstr_broadcastable` constraints.
    SmallVector<Value> uniqueConstraints;
    for (auto &shape : shapes)
      uniqueConstraints.push_back(shape.first.getResult());

    // Replace with a new `assuming_all` operation ...
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);

    // ... and maybe erase `cstr_broadcastable` ops without uses.
    for (auto &op : markedForErase)
      if (op->use_empty())
        rewriter.eraseOp(op);

    return success();
  }
};

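// Illustrative example (not in the original source): the pattern below turns
//   %w0 = shape.cstr_eq %a, %b
//   %w1 = shape.cstr_eq %b, %c
//   %r = shape.assuming_all %w0, %w1
// into the single constraint
//   %r = shape.cstr_eq %a, %b, %c
// provided each `cstr_eq` shares at least one shape with those already seen.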
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};

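// Illustrative example (not in the original source): instantiated for
// CstrBroadcastableOp, the pattern below rewrites
//   shape.cstr_broadcastable %a, %b, %a
// into
//   shape.cstr_broadcastable %a, %b
// by keeping only the first occurrence of each operand.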
template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Find unique operands.
    SetVector<Value> unique(op.operand_begin(), op.operand_end());

    // Reduce op to equivalent with unique operands.
    if (unique.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
                                        unique.takeVector(), op->getAttrs());
      return success();
    }

    return failure();
  }
};
} // namespace

void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}

OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false.
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}

LogicalResult AssumingAllOp::verify() {
  // Ensure that AssumingAllOp contains at least one operand.
  if (getNumOperands() == 0)
    return emitOpError("no operands specified");

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

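// Illustrative example (not in the original source): two constant shapes
// [1, 3] and [2, 1] fold below to the broadcasted constant shape [2, 3].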
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  // TODO: Support folding with more than 2 input shapes.
  if (getShapes().size() > 2)
    return nullptr;

  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

LogicalResult BroadcastOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

namespace {
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto isPotentiallyNonEmptyShape = [](Value shape) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        if (extentTensorTy.getDimSize(0) == 0)
          return false;
      }
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        if (constShape.getShape().empty())
          return false;
      }
      return true;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));

    // Reduce op to equivalent without empty shape operands.
    if (newOperands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
                                        op->getAttrs());
      return success();
    }

    return failure();
  }
};

struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.getShapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (op.getType().isa<ShapeType>()) {
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        assert(!op.getType().isa<ShapeType>() &&
               !replacement.getType().isa<ShapeType>() &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};

struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};

template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information.
        bool isInformationLosingCast =
            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
        if (isInformationLosingCast) {
          anyChange = true;
          return castOp.getSource();
        }
      }
      return operand;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));

    // Rewrite op if any change required.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};

struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

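// Illustrative example (not in the original source): constant shapes [2, 3]
// and [4] fold below to the concatenated constant shape [2, 3, 4].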
OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}

ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}

OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }

void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}

LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
  if (!shape)
    return emitOptionalError(location, "missing shape attribute");
  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(shape.size())}, b.getIndexType())});
  return success();
}

bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
                                                        TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    // Shape type is compatible with all other valid return types.
    return true;
  return lhs == rhs;
}

//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}

// Return true if there is at most one attribute not representing a scalar
// broadcast.
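// For instance (illustrative, not in the original source): the constant
// shapes {[], [2, 3], []} contain a single non-scalar and are therefore
// trivially broadcastable.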
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}

OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are
  // known on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}

LogicalResult CstrBroadcastableOp::verify() {
  // Ensure that CstrBroadcastableOp contains at least two operands.
  if (getNumOperands() < 2)
    return emitOpError("required at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness.
  patterns.add<CstrEqEqOps>(context);
}

OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
  if (llvm::all_of(operands,
                   [&](Attribute a) { return a && a == operands[0]; }))
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly, we
  // cannot if there are any non-const inputs.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }

void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
  return getPassingAttr();
}

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;

  // Division in APInt does not follow floor(lhs, rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
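  // For instance (illustrative, not from the original source): for -7 / 2,
  // sdivrem yields quotient -3 and remainder -1; the adjustment below turns
  // that into the floor result -4.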
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  if (quotient.isNegative() && !remainder.isNullValue()) {
    quotient -= 1;
  }

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}

LogicalResult mlir::shape::DivOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
  bool allSame = true;
  if (!operands.empty() && !operands[0])
    return {};
  for (Attribute operand : operands.drop_front(1)) {
    if (!operand)
      return {};
    allSame = allSame && operand == operands[0];
  }
  return BoolAttr::get(getContext(), allSame);
}

//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//

OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return {};
}

void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// FromExtentsOp
//===----------------------------------------------------------------------===//

OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
  if (llvm::any_of(operands, [](Attribute a) { return !a; }))
    return nullptr;
  SmallVector<int64_t, 6> extents;
  for (auto attr : operands)
    extents.push_back(attr.cast<IntegerAttr>().getInt());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(extents);
}

//===----------------------------------------------------------------------===//
// FunctionLibraryOp
//===----------------------------------------------------------------------===//

void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = getMapping()
                  .get(op->getName().getIdentifier())
                  .dyn_cast_or_null<FlatSymbolRefAttr>();
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}

ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}

void FunctionLibraryOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  p.printOptionalAttrDictWithKeyword(
      (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(getMappingAttr());
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto buildFuncType =
      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
         function_interface_impl::VariadicFlag,
         std::string &) { return builder.getFunctionType(argTypes, results); };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, buildFuncType);
}

void FuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//

Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return llvm::None;
}

OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!elements)
    return nullptr;
  Optional<int64_t> dim = getConstantDim();
  if (!dim.hasValue())
    return nullptr;
  if (dim.getValue() >= elements.getNumElements())
    return nullptr;
  return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
}

void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
                        int64_t dim) {
  auto loc = result.location;
  auto dimAttr = builder.getIndexAttr(dim);
  if (shape.getType().isa<ShapeType>()) {
    Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
    build(builder, result, builder.getType<SizeType>(), shape, dim);
  } else {
    Value dim =
        builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
    build(builder, result, builder.getIndexType(), shape, dim);
  }
}

LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// IsBroadcastableOp
//===----------------------------------------------------------------------===//

void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}

OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // Can always broadcast fewer than two shapes.
  if (operands.size() < 2) {
    return BoolAttr::get(getContext(), true);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// MeetOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}

bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs != rhs)
    return false;

  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
  auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!shape)
    return {};
  int64_t rank = shape.getNumElements();
  Builder builder(getContext());
  return builder.getIndexAttr(rank);
}

/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3

namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}

LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
  // Fold only when argument constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

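  // For instance (illustrative, not from the original source): the constant
  // shape [2, 3, 4] folds to the index value 24.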
  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}

LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult NumElementsOp::verify() {
  return verifySizeOrIndexOp(*this);
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//

OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}

LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult MulOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
  auto type = getOperand().getType().dyn_cast<ShapedType>();
  if (!type || !type.hasStaticShape())
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(type.getShape());
}

namespace {
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getArg().getType().isa<ShapedType>())
      return failure();
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
                                                  op.getArg());
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
} // namespace
1614 
1615 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1616  MLIRContext *context) {
1617  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
1618  ExtractFromShapeOfExtentTensor>(context);
1619 }
1620 
1621 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1622  MLIRContext *context, Optional<Location> location, ValueRange operands,
1623  DictionaryAttr attributes, RegionRange regions,
1624  SmallVectorImpl<Type> &inferredReturnTypes) {
1625  if (operands[0].getType().isa<ValueShapeType>())
1626  inferredReturnTypes.assign({ShapeType::get(context)});
1627  else {
1628  auto shapedTy = operands[0].getType().cast<ShapedType>();
1629  int64_t rank =
1630  shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1631  Type indexTy = IndexType::get(context);
1632  Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1633  inferredReturnTypes.assign({extentTensorTy});
1634  }
1635  return success();
1636 }
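// Illustrative note (not part of the original source): a ranked operand such
// as tensor<4x?xf32> infers tensor<2xindex> here, while an unranked operand
// infers tensor<?xindex> because the rank defaults to kDynamicSize.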
1637 
1638 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1639  if (l.size() != 1 || r.size() != 1)
1640  return false;
1641  if (l == r)
1642  return true;
1643 
1644  Type lhs = l.front();
1645  Type rhs = r.front();
1646 
1647  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
1648  return false;
1649 
1650  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
1651  // Shape type is compatible with all other valid return types.
1652  return true;
1653 
1654  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1655  return true;
1656  return false;
1657 }
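// Illustrative note (not part of the original source): under this predicate
// !shape.shape is compatible with every valid return type, tensor<?xindex>
// is compatible with tensor<3xindex> (pairwise-compatible shapes), and
// tensor<2xindex> is incompatible with tensor<3xindex>.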
1658 
1659 LogicalResult shape::ShapeOfOp::verify() {
1660  return verifyShapeOrExtentTensorOp(*this);
1661 }
1662 
1663 //===----------------------------------------------------------------------===//
1664 // SizeToIndexOp
1665 //===----------------------------------------------------------------------===//
1666 
1667 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1668  // Constant values of both types, `shape.size` and `index`, are represented as
1669  // `IntegerAttr`s, which makes constant folding simple.
1670  if (Attribute arg = operands[0])
1671  return arg;
1672  return OpFoldResult();
1673 }
1674 
1675 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1676  MLIRContext *context) {
1677  patterns.add<IndexToSizeToIndexCanonicalization>(context);
1678 }
1679 
1680 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1681  if (inputs.size() != 1 || outputs.size() != 1)
1682  return false;
1683  return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
1684 }
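// Illustrative note (not part of the original source): a constant operand
// folds straight through, since both types share the IntegerAttr encoding:
// ```
// %0 = shape.const_size 42
// %1 = shape.size_to_index %0 : !shape.size
// ```
// folds %1 to the constant 42 of type index.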
1685 
1686 //===----------------------------------------------------------------------===//
1687 // YieldOp
1688 //===----------------------------------------------------------------------===//
1689 
1690 LogicalResult shape::YieldOp::verify() {
1691  auto *parentOp = (*this)->getParentOp();
1692  auto results = parentOp->getResults();
1693  auto operands = getOperands();
1694 
1695  if (parentOp->getNumResults() != getNumOperands())
1696  return emitOpError() << "number of operands does not match number of "
1697  "results of its parent";
1698  for (auto e : llvm::zip(results, operands))
1699  if (std::get<0>(e).getType() != std::get<1>(e).getType())
1700  return emitOpError() << "types mismatch between yield op and its parent";
1701 
1702  return success();
1703 }
1704 
1705 //===----------------------------------------------------------------------===//
1706 // SplitAtOp
1707 //===----------------------------------------------------------------------===//
1708 
1709 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1710  SmallVectorImpl<OpFoldResult> &results) {
1711  if (!operands[0] || !operands[1])
1712  return failure();
1713  auto shapeVec = llvm::to_vector<6>(
1714  operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1715  auto shape = llvm::makeArrayRef(shapeVec);
1716  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1717  // Verify that the split point is in the correct range.
1718  // TODO: Constant fold to an "error".
1719  int64_t rank = shape.size();
1720  if (-rank > splitPoint || splitPoint > rank)
1721  return failure();
1722  if (splitPoint < 0)
1723  splitPoint += shape.size();
1724  Builder builder(operands[0].getContext());
1725  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1726  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1727  return success();
1728 }
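// Illustrative note (not part of the original source): the fold above treats
// a negative split point as counting from the back. For a constant shape
// [4, 5, 6] and split point -1, the effective index is -1 + 3 = 2, so the
// results are the extent tensors [4, 5] and [6]; a split point outside
// [-3, 3] leaves the op unfolded.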
1729 
1730 //===----------------------------------------------------------------------===//
1731 // ToExtentTensorOp
1732 //===----------------------------------------------------------------------===//
1733 
1734 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1735  if (!operands[0])
1736  return OpFoldResult();
1737  Builder builder(getContext());
1738  auto shape = llvm::to_vector<6>(
1739  operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1740  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1741  builder.getIndexType());
1742  return DenseIntElementsAttr::get(type, shape);
1743 }
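// Illustrative note (not part of the original source): with a constant
// operand, the fold above materializes the extents as a dense index tensor,
// e.g. a constant shape [2, 3] becomes dense<[2, 3]> : tensor<2xindex>.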
1744 
1745 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1746  if (inputs.size() != 1 || outputs.size() != 1)
1747  return false;
1748  if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
1749  if (!inputTensor.getElementType().isa<IndexType>() ||
1750  inputTensor.getRank() != 1)
1751  return false;
1752  } else if (!inputs[0].isa<ShapeType>()) {
1753  return false;
1754  }
1755 
1756  TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
1757  return outputTensor && outputTensor.getElementType().isa<IndexType>();
1758 }
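// Illustrative note (not part of the original source): compatible casts
// include !shape.shape -> tensor<?xindex> and tensor<3xindex> ->
// tensor<3xindex>; a tensor<?xf32> input is rejected because its element
// type is not index, as is any input tensor of rank != 1.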
1759 
1760 //===----------------------------------------------------------------------===//
1761 // ReduceOp
1762 //===----------------------------------------------------------------------===//
1763 
1764 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1765  ValueRange initVals) {
1766  result.addOperands(shape);
1767  result.addOperands(initVals);
1768 
1769  Region *bodyRegion = result.addRegion();
1770  bodyRegion->push_back(new Block);
1771  Block &bodyBlock = bodyRegion->front();
1772  bodyBlock.addArgument(builder.getIndexType(), result.location);
1773 
1774  Type elementType;
1775  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1776  elementType = tensorType.getElementType();
1777  else
1778  elementType = SizeType::get(builder.getContext());
1779  bodyBlock.addArgument(elementType, shape.getLoc());
1780 
1781  for (Value initVal : initVals) {
1782  bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
1783  result.addTypes(initVal.getType());
1784  }
1785 }
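// Illustrative usage (not part of the original source; `loc`, `extentTensor`
// and `init` are assumed values, with `extentTensor : tensor<?xindex>`):
// ```
// auto reduce = builder.create<shape::ReduceOp>(loc, extentTensor,
//                                               ValueRange{init});
// ```
// The entry block then carries arguments (index, index, init.getType()),
// i.e. the iteration index, the extent, and the running accumulation.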
1786 
1787 LogicalResult ReduceOp::verify() {
1788  // Verify block arg types.
1789  Block &block = getRegion().front();
1790 
1791  // The block takes index, extent, and aggregated values as arguments.
1792  auto blockArgsCount = getInitVals().size() + 2;
1793  if (block.getNumArguments() != blockArgsCount)
1794  return emitOpError() << "ReduceOp body is expected to have "
1795  << blockArgsCount << " arguments";
1796 
1797  // The first block argument is the index and must always be of type `index`.
1798  if (!block.getArgument(0).getType().isa<IndexType>())
1799  return emitOpError(
1800  "argument 0 of ReduceOp body is expected to be of IndexType");
1801 
1802  // The second block argument is the extent and must be of type `size` or
1803  // `index`, depending on whether the reduce operation is applied to a shape or
1804  // to an extent tensor.
1805  Type extentTy = block.getArgument(1).getType();
1806  if (getShape().getType().isa<ShapeType>()) {
1807  if (!extentTy.isa<SizeType>())
1808  return emitOpError("argument 1 of ReduceOp body is expected to be of "
1809  "SizeType if the ReduceOp operates on a ShapeType");
1810  } else {
1811  if (!extentTy.isa<IndexType>())
1812  return emitOpError(
1813  "argument 1 of ReduceOp body is expected to be of IndexType if the "
1814  "ReduceOp operates on an extent tensor");
1815  }
1816 
1817  for (const auto &type : llvm::enumerate(getInitVals()))
1818  if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1819  return emitOpError() << "type mismatch between argument "
1820  << type.index() + 2
1821  << " of ReduceOp body and initial value "
1822  << type.index();
1823  return success();
1824 }
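// A well-formed example the verifier above accepts (illustrative, adapted
// from the dialect documentation):
// ```
// %sum = shape.reduce(%shape, %c0) : tensor<?xindex> -> index {
//   ^bb0(%index : index, %extent : index, %acc : index):
//     %new_acc = arith.addi %acc, %extent : index
//     shape.yield %new_acc : index
// }
// ```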
1825 
1826 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1827  // Parse operands.
1828  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1829  Type shapeOrExtentTensorType;
1830  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1831  OpAsmParser::Delimiter::Paren) ||
1832  parser.parseColonType(shapeOrExtentTensorType) ||
1833  parser.parseOptionalArrowTypeList(result.types))
1834  return failure();
1835 
1836  // Resolve operands.
1837  auto initVals = llvm::makeArrayRef(operands).drop_front();
1838  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1839  result.operands) ||
1840  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1841  result.operands))
1842  return failure();
1843 
1844  // Parse the body.
1845  Region *body = result.addRegion();
1846  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1847  return failure();
1848 
1849  // Parse attributes.
1850  if (parser.parseOptionalAttrDict(result.attributes))
1851  return failure();
1852 
1853  return success();
1854 }
1855 
1856 void ReduceOp::print(OpAsmPrinter &p) {
1857  p << '(' << getShape() << ", " << getInitVals()
1858  << ") : " << getShape().getType();
1859  p.printOptionalArrowTypeList(getResultTypes());
1860  p << ' ';
1861  p.printRegion(getRegion());
1862  p.printOptionalAttrDict((*this)->getAttrs());
1863 }
1864 
1865 #define GET_OP_CLASSES
1866 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1867 
1868 #define GET_TYPEDEF_CLASSES
1869 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"