MLIR  22.0.0git
Shape.cpp
Go to the documentation of this file.
1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
12 
17 #include "mlir/Dialect/Traits.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
27 #include "llvm/ADT/SetOperations.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/raw_ostream.h"
30 
31 using namespace mlir;
32 using namespace mlir::shape;
33 
34 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
35 
36 namespace {
37 #include "ShapeCanonicalization.inc"
38 } // namespace
39 
40 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
41  return RankedTensorType::get({rank}, IndexType::get(ctx));
42 }
43 
45  auto ranked = llvm::dyn_cast<RankedTensorType>(type);
46  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
47 }
48 
49 LogicalResult shape::getShapeVec(Value input,
50  SmallVectorImpl<int64_t> &shapeValues) {
51  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
52  auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
53  if (!type.hasRank())
54  return failure();
55  llvm::append_range(shapeValues, type.getShape());
56  return success();
57  }
59  if (matchPattern(input, m_Constant(&attr))) {
60  llvm::append_range(shapeValues, attr.getValues<int64_t>());
61  return success();
62  }
63  return failure();
64 }
65 
66 static bool isErrorPropagationPossible(TypeRange operandTypes) {
67  return llvm::any_of(operandTypes,
68  llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
69 }
70 
71 static LogicalResult verifySizeOrIndexOp(Operation *op) {
72  assert(op != nullptr && op->getNumResults() == 1);
73  Type resultTy = op->getResultTypes().front();
75  if (!llvm::isa<SizeType>(resultTy))
76  return op->emitOpError()
77  << "if at least one of the operands can hold error values then "
78  "the result must be of type `size` to propagate them";
79  }
80  return success();
81 }
82 
83 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
84  assert(op != nullptr && op->getNumResults() == 1);
85  Type resultTy = op->getResultTypes().front();
87  if (!llvm::isa<ShapeType>(resultTy))
88  return op->emitOpError()
89  << "if at least one of the operands can hold error values then "
90  "the result must be of type `shape` to propagate them";
91  }
92  return success();
93 }
94 
95 template <typename... Ty>
96 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
97  return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
98 }
99 
100 template <typename... Ty, typename... ranges>
101 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
102  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // InlinerInterface
107 //===----------------------------------------------------------------------===//
108 
109 namespace {
110 /// This class defines the interface for inlining shape dialect ops.
111 struct ShapeInlinerInterface : public DialectInlinerInterface {
113 
114  // Returns true if the given region 'src' can be inlined into the region
115  // 'dest' that is attached to an operation registered to the current dialect.
116  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
117  IRMapping &) const final {
118  return true;
119  }
120 
121  // Returns true if the given operation 'op', that is registered to this
122  // dialect, can be inlined into the region 'dest' that is attached to an
123  // operation registered to the current dialect.
124  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
125  IRMapping &) const final {
126  return true;
127  }
128 };
129 } // namespace
130 
131 void ShapeDialect::initialize() {
132  addOperations<
133 #define GET_OP_LIST
134 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
135  >();
136  addTypes<
137 #define GET_TYPEDEF_LIST
138 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
139  >();
140  addInterfaces<ShapeInlinerInterface>();
141  // Allow unknown operations during prototyping and testing. As the dialect is
142  // still evolving it makes it simple to start with an unregistered ops and
143  // try different variants before actually defining the op.
144  allowUnknownOperations();
145  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
146  AssumingYieldOp>();
147 }
148 
150  Attribute value, Type type,
151  Location loc) {
152  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
153  return ub::PoisonOp::create(builder, loc, type, poison);
154 
155  if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
156  return ConstShapeOp::create(builder, loc, type,
157  llvm::cast<DenseIntElementsAttr>(value));
158  if (llvm::isa<SizeType>(type))
159  return ConstSizeOp::create(builder, loc, type,
160  llvm::cast<IntegerAttr>(value));
161  if (llvm::isa<WitnessType>(type))
162  return ConstWitnessOp::create(builder, loc, type,
163  llvm::cast<BoolAttr>(value));
164 
165  return arith::ConstantOp::materialize(builder, value, type, loc);
166 }
167 
168 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
169  NamedAttribute attribute) {
170  // Verify shape.lib attribute.
171  if (attribute.getName() == "shape.lib") {
172  if (!op->hasTrait<OpTrait::SymbolTable>())
173  return op->emitError(
174  "shape.lib attribute may only be on op implementing SymbolTable");
175 
176  if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
177  auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
178  if (!symbol)
179  return op->emitError("shape function library ")
180  << symbolRef << " not found";
181  return isa<shape::FunctionLibraryOp>(symbol)
182  ? success()
183  : op->emitError()
184  << symbolRef << " required to be shape function library";
185  }
186 
187  if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
188  // Verify all entries are function libraries and mappings in libraries
189  // refer to unique ops.
191  for (auto it : arr) {
192  if (!llvm::isa<SymbolRefAttr>(it))
193  return op->emitError(
194  "only SymbolRefAttr allowed in shape.lib attribute array");
195 
196  auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
197  SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
198  if (!shapeFnLib)
199  return op->emitError()
200  << it << " does not refer to FunctionLibraryOp";
201  for (auto mapping : shapeFnLib.getMapping()) {
202  if (!key.insert(mapping.getName()).second) {
203  return op->emitError("only one op to shape mapping allowed, found "
204  "multiple for `")
205  << mapping.getName() << "`";
206  }
207  }
208  }
209  return success();
210  }
211 
212  return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
213  "allowed as shape.lib attribute");
214  }
215  return success();
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // AnyOp
220 //===----------------------------------------------------------------------===//
221 
222 // TODO: Canonicalization should be implemented for shapes that can be
223 // determined through mixtures of the known dimensions of the inputs.
224 OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
225  // Only the last operand is checked because AnyOp is commutative.
226  if (adaptor.getInputs().back())
227  return adaptor.getInputs().back();
228 
229  return nullptr;
230 }
231 
232 //===----------------------------------------------------------------------===//
233 // AssumingOp
234 //===----------------------------------------------------------------------===//
235 
236 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
237  result.regions.reserve(1);
238  Region *doRegion = result.addRegion();
239 
240  auto &builder = parser.getBuilder();
242  if (parser.parseOperand(cond) ||
243  parser.resolveOperand(cond, builder.getType<WitnessType>(),
244  result.operands))
245  return failure();
246 
247  // Parse optional results type list.
248  if (parser.parseOptionalArrowTypeList(result.types))
249  return failure();
250 
251  // Parse the region and add a terminator if elided.
252  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
253  return failure();
254  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
255 
256  // Parse the optional attribute list.
257  if (parser.parseOptionalAttrDict(result.attributes))
258  return failure();
259  return success();
260 }
261 
263  bool yieldsResults = !getResults().empty();
264 
265  p << " " << getWitness();
266  if (yieldsResults)
267  p << " -> (" << getResultTypes() << ")";
268  p << ' ';
269  p.printRegion(getDoRegion(),
270  /*printEntryBlockArgs=*/false,
271  /*printBlockTerminators=*/yieldsResults);
272  p.printOptionalAttrDict((*this)->getAttrs());
273 }
274 
275 namespace {
276 // Removes AssumingOp with a passing witness and inlines the region.
277 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
279 
280  LogicalResult matchAndRewrite(AssumingOp op,
281  PatternRewriter &rewriter) const override {
282  auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
283  if (!witness || !witness.getPassingAttr())
284  return failure();
285 
286  AssumingOp::inlineRegionIntoParent(op, rewriter);
287  return success();
288  }
289 };
290 
291 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
293 
294  LogicalResult matchAndRewrite(AssumingOp op,
295  PatternRewriter &rewriter) const override {
296  Block *body = op.getBody();
297  auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
298 
299  // Find used values.
300  SmallVector<Value, 4> newYieldOperands;
301  for (auto [opResult, yieldOperand] :
302  llvm::zip(op.getResults(), yieldOp.getOperands())) {
303  if (!opResult.getUses().empty()) {
304  newYieldOperands.push_back(yieldOperand);
305  }
306  }
307 
308  // Rewrite only if redundant results exist.
309  if (newYieldOperands.size() == yieldOp->getNumOperands())
310  return failure();
311 
312  // Replace yield op in the old assuming op's body and move the entire region
313  // to the new assuming op.
314  rewriter.setInsertionPointToEnd(body);
315  auto newYieldOp =
316  rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
317  rewriter.setInsertionPoint(op);
318  auto newOp = AssumingOp::create(
319  rewriter, op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
320  newOp.getDoRegion().takeBody(op.getDoRegion());
321 
322  // Use the new results to replace the previously used ones.
323  SmallVector<Value, 4> replacementValues;
324  auto src = newOp.getResults().begin();
325  for (auto it : op.getResults()) {
326  if (it.getUses().empty())
327  replacementValues.push_back(nullptr);
328  else
329  replacementValues.push_back(*src++);
330  }
331  rewriter.replaceOp(op, replacementValues);
332  return success();
333  }
334 };
335 } // namespace
336 
337 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
338  MLIRContext *context) {
339  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
340 }
341 
342 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
343 void AssumingOp::getSuccessorRegions(
345  // AssumingOp has unconditional control flow into the region and back to the
346  // parent, so return the correct RegionSuccessor purely based on the index
347  // being None or 0.
348  if (!point.isParent()) {
349  regions.push_back(RegionSuccessor(getResults()));
350  return;
351  }
352 
353  regions.push_back(RegionSuccessor(&getDoRegion()));
354 }
355 
356 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
357  PatternRewriter &rewriter) {
358  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
359  auto *assumingBlock = op.getBody();
360  auto initPosition = rewriter.getInsertionPoint();
361  auto *blockAfterAssuming =
362  rewriter.splitBlock(blockBeforeAssuming, initPosition);
363 
364  // Remove the AssumingOp and AssumingYieldOp.
365  auto &yieldOp = assumingBlock->back();
366  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
367  rewriter.replaceOp(op, yieldOp.getOperands());
368  rewriter.eraseOp(&yieldOp);
369 
370  // Merge blocks together as there was no branching behavior from the
371  // AssumingOp.
372  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
373  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
374 }
375 
376 void AssumingOp::build(
377  OpBuilder &builder, OperationState &result, Value witness,
379  OpBuilder::InsertionGuard g(builder);
380 
381  result.addOperands(witness);
382  Region *bodyRegion = result.addRegion();
383  builder.createBlock(bodyRegion);
384 
385  // Build body.
386  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
387  AssumingYieldOp::create(builder, result.location, yieldValues);
388 
389  SmallVector<Type, 2> assumingTypes;
390  for (Value v : yieldValues)
391  assumingTypes.push_back(v.getType());
392  result.addTypes(assumingTypes);
393 }
394 
395 //===----------------------------------------------------------------------===//
396 // AddOp
397 //===----------------------------------------------------------------------===//
398 
399 LogicalResult mlir::shape::AddOp::inferReturnTypes(
400  MLIRContext *context, std::optional<Location> location,
401  AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
402  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
403  llvm::isa<SizeType>(adaptor.getRhs().getType()))
404  inferredReturnTypes.assign({SizeType::get(context)});
405  else
406  inferredReturnTypes.assign({IndexType::get(context)});
407  return success();
408 }
409 
410 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
411  // SizeType is compatible with IndexType.
412  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
413 }
414 
415 OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
416  // add(x, 0) -> x
417  if (matchPattern(getRhs(), m_Zero()))
418  return getLhs();
419 
420  return constFoldBinaryOp<IntegerAttr>(
421  adaptor.getOperands(),
422  [](APInt a, const APInt &b) { return std::move(a) + b; });
423 }
424 
425 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
426 
427 //===----------------------------------------------------------------------===//
428 // AssumingAllOp
429 //===----------------------------------------------------------------------===//
430 
431 namespace {
432 
433 // Merge multiple `shape.assuming_all` operations together.
434 //
435 // %0 = shape.assuming_all %w0, %w1
436 // %1 = shape.assuming_all %w2, %0
437 //
438 // to:
439 //
440 // %0 = shape.assuming_all %w0, %w2, %w2
441 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
443 
444  LogicalResult matchAndRewrite(AssumingAllOp op,
445  PatternRewriter &rewriter) const override {
446  SmallVector<Value> operands;
447 
448  for (Value operand : op.getInputs()) {
449  if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
450  operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
451  else
452  operands.push_back(operand);
453  }
454 
455  // We didn't find any other `assuming_all` ops to merge with.
456  if (operands.size() == op.getNumOperands())
457  return failure();
458 
459  // Replace with a new `assuming_all` operation with merged constraints.
460  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
461  return success();
462  }
463 };
464 
465 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
466 // are subsumed by others.
467 //
468 // %0 = shape.cstr_broadcastable %shape0, %shape1
469 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
470 //
471 // %2 = shape.cstr_broadcastable %shape3, %shape4
472 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
473 //
474 // %4 = shape.assuming_all %0, %1, %2, %3
475 //
476 // to:
477 //
478 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
479 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
480 // %2 = shape.assuming_all %0, %1
481 //
482 // In this example if shapes [0, 1, 2] are broadcastable, then it means that
483 // shapes [0, 1] are broadcastable too, and can be removed from the list of
484 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
485 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
486 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
488 
489  LogicalResult matchAndRewrite(AssumingAllOp op,
490  PatternRewriter &rewriter) const override {
491  // Collect all `CstrBroadcastableOp` operands first.
493  for (Value operand : op.getInputs()) {
494  // TODO: Apply this optimization if some of the witnesses are not
495  // produced by the `cstr_broadcastable`.
496  auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
497  if (!broadcastable)
498  return failure();
499 
500  operands.insert(broadcastable);
501  }
502 
503  // Skip trivial `assuming_all` operations.
504  if (operands.size() <= 1)
505  return failure();
506 
507  // Collect shapes checked by `cstr_broadcastable` operands.
509  for (auto cstr : operands) {
510  DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
511  shapes.emplace_back(cstr, std::move(shapesSet));
512  }
513 
514  // Sort by the number of shape operands (larger to smaller).
515  llvm::sort(shapes, [](auto a, auto b) {
516  return a.first.getNumOperands() > b.first.getNumOperands();
517  });
518 
519  // We start from the `cst_broadcastable` operations with largest number of
520  // shape operands, and remove redundant `cst_broadcastable` operations. We
521  // do this until we find a set of `cst_broadcastable` operations with
522  // non-overlapping constraints.
523  SmallVector<CstrBroadcastableOp> markedForErase;
524 
525  for (unsigned i = 0; i < shapes.size(); ++i) {
526  auto isSubset = [&](auto pair) {
527  return llvm::set_is_subset(pair.second, shapes[i].second);
528  };
529 
530  // Keep redundant `cstr_broadcastable` operations to be erased.
531  auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
532  for (auto *it0 = it; it0 < shapes.end(); ++it0)
533  markedForErase.push_back(it0->first);
534  shapes.erase(it, shapes.end());
535  }
536 
537  // We didn't find any operands that could be removed.
538  if (markedForErase.empty())
539  return failure();
540 
541  // Collect non-overlapping `cst_broadcastable` constraints.
542  SmallVector<Value> uniqueConstraints;
543  for (auto &shape : shapes)
544  uniqueConstraints.push_back(shape.first.getResult());
545 
546  // Replace with a new `assuming_all` operation ...
547  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
548 
549  // ... and maybe erase `cstr_broadcastable` ops without uses.
550  for (auto &op : markedForErase)
551  if (op->use_empty())
552  rewriter.eraseOp(op);
553 
554  return success();
555  }
556 };
557 
558 struct AssumingAllToCstrEqCanonicalization
559  : public OpRewritePattern<AssumingAllOp> {
561 
562  LogicalResult matchAndRewrite(AssumingAllOp op,
563  PatternRewriter &rewriter) const override {
564  SmallVector<Value, 8> shapes;
565  for (Value w : op.getInputs()) {
566  auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
567  if (!cstrEqOp)
568  return failure();
569  bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
570  return llvm::is_contained(shapes, s);
571  });
572  if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
573  return failure();
574  shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
575  }
576  rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
577  return success();
578  }
579 };
580 
581 template <typename OpTy>
582 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
584 
585  LogicalResult matchAndRewrite(OpTy op,
586  PatternRewriter &rewriter) const override {
587  // Find unique operands.
588  SetVector<Value> unique(op.operand_begin(), op.operand_end());
589 
590  // Reduce op to equivalent with unique operands.
591  if (unique.size() < op.getNumOperands()) {
592  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
593  unique.takeVector(), op->getAttrs());
594  return success();
595  }
596 
597  return failure();
598  }
599 };
600 } // namespace
601 
602 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
603  MLIRContext *context) {
604  patterns
605  .add<MergeAssumingAllOps, AssumingAllOneOp,
606  AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
607  RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
608 }
609 
610 OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
611  // Iterate in reverse to first handle all constant operands. They are
612  // guaranteed to be the tail of the inputs because this is commutative.
613  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
614  Attribute a = adaptor.getInputs()[idx];
615  // Cannot fold if any inputs are not constant;
616  if (!a)
617  return nullptr;
618 
619  // We do not need to keep statically known values after handling them in
620  // this method.
621  getOperation()->eraseOperand(idx);
622 
623  // Always false if any input is statically known false
624  if (!llvm::cast<BoolAttr>(a).getValue())
625  return a;
626  }
627  // If this is reached, all inputs were statically known passing.
628  return BoolAttr::get(getContext(), true);
629 }
630 
631 LogicalResult AssumingAllOp::verify() {
632  // Ensure that AssumingAllOp contains at least one operand
633  if (getNumOperands() == 0)
634  return emitOpError("no operands specified");
635 
636  return success();
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // BroadcastOp
641 //===----------------------------------------------------------------------===//
642 
643 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
644  if (getShapes().size() == 1) {
645  // Otherwise, we need a cast which would be a canonicalization, not folding.
646  if (getShapes().front().getType() != getType())
647  return nullptr;
648  return getShapes().front();
649  }
650 
651  if (!adaptor.getShapes().front())
652  return nullptr;
653 
654  SmallVector<int64_t, 6> resultShape(
655  llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
656  .getValues<int64_t>());
657 
658  for (auto next : adaptor.getShapes().drop_front()) {
659  if (!next)
660  return nullptr;
661  auto nextShape = llvm::to_vector<6>(
662  llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
663 
664  SmallVector<int64_t, 6> tmpShape;
665  // If the shapes are not compatible, we can't fold it.
666  // TODO: Fold to an "error".
667  if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
668  return nullptr;
669 
670  resultShape.clear();
671  std::copy(tmpShape.begin(), tmpShape.end(),
672  std::back_inserter(resultShape));
673  }
674 
675  Builder builder(getContext());
676  return builder.getIndexTensorAttr(resultShape);
677 }
678 
679 LogicalResult BroadcastOp::verify() {
680  return verifyShapeOrExtentTensorOp(*this);
681 }
682 
683 namespace {
684 template <typename OpTy>
685 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
687 
688  LogicalResult matchAndRewrite(OpTy op,
689  PatternRewriter &rewriter) const override {
690  auto isPotentiallyNonEmptyShape = [](Value shape) {
691  if (auto extentTensorTy =
692  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
693  if (extentTensorTy.getDimSize(0) == 0)
694  return false;
695  }
696  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
697  if (constShape.getShape().empty())
698  return false;
699  }
700  return true;
701  };
702  auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
703  isPotentiallyNonEmptyShape);
704 
705  // Replace the op with empty shape constant if all operants are reduced to
706  // be empty.
707  if (newOperands.empty()) {
708  rewriter.replaceOpWithNewOp<ConstShapeOp>(
709  op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
710  return success();
711  }
712 
713  // Reduce op to equivalent without empty shape operands.
714  if (newOperands.size() < op.getNumOperands()) {
715  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
716  op->getAttrs());
717  return success();
718  }
719 
720  return failure();
721  }
722 };
723 
724 struct BroadcastForwardSingleOperandPattern
725  : public OpRewritePattern<BroadcastOp> {
727 
728  LogicalResult matchAndRewrite(BroadcastOp op,
729  PatternRewriter &rewriter) const override {
730  if (op.getNumOperands() != 1)
731  return failure();
732  Value replacement = op.getShapes().front();
733 
734  // Insert cast if needed.
735  if (replacement.getType() != op.getType()) {
736  auto loc = op.getLoc();
737  if (llvm::isa<ShapeType>(op.getType())) {
738  replacement = FromExtentTensorOp::create(rewriter, loc, replacement);
739  } else {
740  assert(!llvm::isa<ShapeType>(op.getType()) &&
741  !llvm::isa<ShapeType>(replacement.getType()) &&
742  "expect extent tensor cast");
743  replacement =
744  tensor::CastOp::create(rewriter, loc, op.getType(), replacement);
745  }
746  }
747 
748  rewriter.replaceOp(op, replacement);
749  return success();
750  }
751 };
752 
753 struct BroadcastFoldConstantOperandsPattern
754  : public OpRewritePattern<BroadcastOp> {
756 
757  LogicalResult matchAndRewrite(BroadcastOp op,
758  PatternRewriter &rewriter) const override {
759  SmallVector<int64_t, 8> foldedConstantShape;
760  SmallVector<Value, 8> newShapeOperands;
761  for (Value shape : op.getShapes()) {
762  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
763  SmallVector<int64_t, 8> newFoldedConstantShape;
765  foldedConstantShape,
766  llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
767  newFoldedConstantShape)) {
768  foldedConstantShape = newFoldedConstantShape;
769  continue;
770  }
771  }
772  newShapeOperands.push_back(shape);
773  }
774 
775  // Need at least two constant operands to fold anything.
776  if (op.getNumOperands() - newShapeOperands.size() < 2)
777  return failure();
778 
779  auto foldedConstantOperandsTy = RankedTensorType::get(
780  {static_cast<int64_t>(foldedConstantShape.size())},
781  rewriter.getIndexType());
782  newShapeOperands.push_back(
783  ConstShapeOp::create(rewriter, op.getLoc(), foldedConstantOperandsTy,
784  rewriter.getIndexTensorAttr(foldedConstantShape)));
785  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
786  newShapeOperands);
787  return success();
788  }
789 };
790 
791 template <typename OpTy>
792 struct CanonicalizeCastExtentTensorOperandsPattern
793  : public OpRewritePattern<OpTy> {
795 
796  LogicalResult matchAndRewrite(OpTy op,
797  PatternRewriter &rewriter) const override {
798  // Canonicalize operands.
799  bool anyChange = false;
800  auto canonicalizeOperand = [&](Value operand) -> Value {
801  if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
802  // Only eliminate the cast if it holds no shape information.
803  bool isInformationLoosingCast =
804  llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
805  if (isInformationLoosingCast) {
806  anyChange = true;
807  return castOp.getSource();
808  }
809  }
810  return operand;
811  };
812  auto newOperands = llvm::to_vector<8>(
813  llvm::map_range(op.getOperands(), canonicalizeOperand));
814 
815  // Rewrite op if any change required.
816  if (!anyChange)
817  return failure();
818  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
819  return success();
820  }
821 };
822 
823 struct BroadcastConcretizeResultTypePattern
824  : public OpRewritePattern<BroadcastOp> {
826 
827  LogicalResult matchAndRewrite(BroadcastOp op,
828  PatternRewriter &rewriter) const override {
829  // Only concretize dynamic extent tensor result types.
830  auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
831  if (!resultTy || !resultTy.isDynamicDim(0))
832  return failure();
833 
834  // Infer resulting shape rank if possible.
835  int64_t maxRank = 0;
836  for (Value shape : op.getShapes()) {
837  if (auto extentTensorTy =
838  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
839  // Cannot infer resulting shape rank if any operand is dynamically
840  // ranked.
841  if (extentTensorTy.isDynamicDim(0))
842  return failure();
843  maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
844  }
845  }
846 
847  auto newOp = BroadcastOp::create(rewriter, op.getLoc(),
848  getExtentTensorType(getContext(), maxRank),
849  op.getShapes());
850  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
851  return success();
852  }
853 };
854 } // namespace
855 
856 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
857  MLIRContext *context) {
858  patterns.add<BroadcastConcretizeResultTypePattern,
859  BroadcastFoldConstantOperandsPattern,
860  BroadcastForwardSingleOperandPattern,
861  CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
862  RemoveDuplicateOperandsPattern<BroadcastOp>,
863  RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
864 }
865 
866 //===----------------------------------------------------------------------===//
867 // ConcatOp
868 //===----------------------------------------------------------------------===//
869 
870 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
871  if (!adaptor.getLhs() || !adaptor.getRhs())
872  return nullptr;
873  auto lhsShape = llvm::to_vector<6>(
874  llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
875  auto rhsShape = llvm::to_vector<6>(
876  llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
877  SmallVector<int64_t, 6> resultShape;
878  resultShape.append(lhsShape.begin(), lhsShape.end());
879  resultShape.append(rhsShape.begin(), rhsShape.end());
880  Builder builder(getContext());
881  return builder.getIndexTensorAttr(resultShape);
882 }
883 
884 //===----------------------------------------------------------------------===//
885 // ConstShapeOp
886 //===----------------------------------------------------------------------===//
887 
889  p << " ";
890  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
891  p << "[";
892  interleaveComma(getShape().getValues<int64_t>(), p);
893  p << "] : ";
894  p.printType(getType());
895 }
896 
897 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
898  if (parser.parseOptionalAttrDict(result.attributes))
899  return failure();
900  // We piggy-back on ArrayAttr parsing, though we don't internally store the
901  // shape as an ArrayAttr.
902  // TODO: Implement custom parser and maybe make syntax a bit more concise.
903  Attribute extentsRaw;
904  NamedAttrList dummy;
905  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
906  return failure();
907  auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
908  if (!extentsArray)
909  return failure();
911  for (Attribute extent : extentsArray) {
912  IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
913  if (!attr)
914  return failure();
915  ints.push_back(attr.getInt());
916  }
917  Builder &builder = parser.getBuilder();
918  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
919  Type resultTy;
920  if (parser.parseColonType(resultTy))
921  return failure();
922  result.types.push_back(resultTy);
923  return success();
924 }
925 
926 OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
927 
928 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
929  MLIRContext *context) {
930  patterns.add<TensorCastConstShape>(context);
931 }
932 
933 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
934  MLIRContext *context, std::optional<Location> location,
935  ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
936  Builder b(context);
937  const Properties prop = adaptor.getProperties();
938  inferredReturnTypes.assign({RankedTensorType::get(
939  {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
940  return success();
941 }
942 
943 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
944  TypeRange r) {
945  if (l.size() != 1 || r.size() != 1)
946  return false;
947 
948  Type lhs = l.front();
949  Type rhs = r.front();
950 
951  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
952  // Shape type is compatible with all other valid return types.
953  return true;
954  return lhs == rhs;
955 }
956 
957 //===----------------------------------------------------------------------===//
958 // CstrBroadcastableOp
959 //===----------------------------------------------------------------------===//
960 
961 void CstrBroadcastableOp::getCanonicalizationPatterns(
963  // Canonicalization patterns have overlap with the considerations during
964  // folding in case additional shape information is inferred at some point that
965  // does not result in folding.
966  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
967  CstrBroadcastableEqOps,
968  RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
969  RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
970 }
971 
972 // Return true if there is at most one attribute not representing a scalar
973 // broadcast.
975  bool nonScalarSeen = false;
976  for (Attribute a : attributes) {
977  if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
978  if (nonScalarSeen)
979  return false;
980  nonScalarSeen = true;
981  }
982  }
983  return true;
984 }
985 
986 OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
987  // No broadcasting is needed if all operands but one are scalar.
988  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
989  return BoolAttr::get(getContext(), true);
990 
991  if ([&] {
993  for (const auto &operand : adaptor.getShapes()) {
994  if (!operand)
995  return false;
996  extents.push_back(llvm::to_vector<6>(
997  llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
998  }
1000  }())
1001  return BoolAttr::get(getContext(), true);
1002 
1003  // Lastly, see if folding can be completed based on what constraints are known
1004  // on the input shapes.
1005  if ([&] {
1007  for (auto shapeValue : getShapes()) {
1008  extents.emplace_back();
1009  if (failed(getShapeVec(shapeValue, extents.back())))
1010  return false;
1011  }
1013  }())
1014  return BoolAttr::get(getContext(), true);
1015 
1016  // Because a failing witness result here represents an eventual assertion
1017  // failure, we do not replace it with a constant witness.
1018  return nullptr;
1019 }
1020 
1021 LogicalResult CstrBroadcastableOp::verify() {
1022  // Ensure that CstrBroadcastableOp contains at least two operands
1023  if (getNumOperands() < 2)
1024  return emitOpError("required at least 2 input shapes");
1025  return success();
1026 }
1027 
1028 //===----------------------------------------------------------------------===//
1029 // CstrEqOp
1030 //===----------------------------------------------------------------------===//
1031 
1032 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1033  MLIRContext *context) {
1034  // If inputs are equal, return passing witness
1035  patterns.add<CstrEqEqOps>(context);
1036 }
1037 
1038 OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
1039  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
1040  return a && a == adaptor.getShapes().front();
1041  }))
1042  return BoolAttr::get(getContext(), true);
1043 
1044  // Because a failing witness result here represents an eventual assertion
1045  // failure, we do not try to replace it with a constant witness. Similarly, we
1046  // cannot if there are any non-const inputs.
1047  return nullptr;
1048 }
1049 
1050 //===----------------------------------------------------------------------===//
1051 // ConstSizeOp
1052 //===----------------------------------------------------------------------===//
1053 
1054 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
1055  int64_t value) {
1056  build(builder, result, builder.getIndexAttr(value));
1057 }
1058 
1059 OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
1060 
1061 void ConstSizeOp::getAsmResultNames(
1062  llvm::function_ref<void(Value, StringRef)> setNameFn) {
1063  SmallString<4> buffer;
1064  llvm::raw_svector_ostream os(buffer);
1065  os << "c" << getValue();
1066  setNameFn(getResult(), os.str());
1067 }
1068 
1069 //===----------------------------------------------------------------------===//
1070 // ConstWitnessOp
1071 //===----------------------------------------------------------------------===//
1072 
1073 OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
1074 
1075 //===----------------------------------------------------------------------===//
1076 // CstrRequireOp
1077 //===----------------------------------------------------------------------===//
1078 
1079 OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
1080  return adaptor.getPred();
1081 }
1082 
1083 //===----------------------------------------------------------------------===//
1084 // DimOp
1085 //===----------------------------------------------------------------------===//
1086 
1087 std::optional<int64_t> DimOp::getConstantIndex() {
1088  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1089  return constSizeOp.getValue().getLimitedValue();
1090  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1091  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1092  return std::nullopt;
1093 }
1094 
1095 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1096  Type valType = getValue().getType();
1097  auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1098  if (!valShapedType || !valShapedType.hasRank())
1099  return nullptr;
1100  std::optional<int64_t> index = getConstantIndex();
1101  if (!index.has_value())
1102  return nullptr;
1103  if (index.value() < 0 || index.value() >= valShapedType.getRank())
1104  return nullptr;
1105  auto extent = valShapedType.getDimSize(*index);
1106  if (ShapedType::isDynamic(extent))
1107  return nullptr;
1108  return IntegerAttr::get(IndexType::get(getContext()), extent);
1109 }
1110 
1111 LogicalResult mlir::shape::DimOp::inferReturnTypes(
1112  MLIRContext *context, std::optional<Location> location,
1113  DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1114  inferredReturnTypes.assign({adaptor.getIndex().getType()});
1115  return success();
1116 }
1117 
1118 bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1119  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1120 }
1121 
1122 //===----------------------------------------------------------------------===//
1123 // DivOp
1124 //===----------------------------------------------------------------------===//
1125 
1126 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
1127  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1128  if (!lhs)
1129  return nullptr;
1130  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1131  if (!rhs || rhs.getValue().isZero())
1132  return nullptr;
1133 
1134  // Division in APInt does not follow floor(lhs, rhs) when the result is
1135  // negative. Rather, APInt rounds toward zero.
1136  APInt quotient, remainder;
1137  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1138  if (quotient.isNegative() && !remainder.isZero()) {
1139  quotient -= 1;
1140  }
1141 
1142  Type indexTy = IndexType::get(getContext());
1143  return IntegerAttr::get(indexTy, quotient);
1144 }
1145 
1146 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1147  MLIRContext *context, std::optional<Location> location,
1148  DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1149  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1150  llvm::isa<SizeType>(adaptor.getRhs().getType()))
1151  inferredReturnTypes.assign({SizeType::get(context)});
1152  else
1153  inferredReturnTypes.assign({IndexType::get(context)});
1154  return success();
1155 }
1156 
1157 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1158  // SizeType is compatible with IndexType.
1159  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1160 }
1161 
1162 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1163 
1164 //===----------------------------------------------------------------------===//
1165 // ShapeEqOp
1166 //===----------------------------------------------------------------------===//
1167 
1168 OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
1169  bool allSame = true;
1170  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1171  return {};
1172  for (Attribute operand : adaptor.getShapes().drop_front()) {
1173  if (!operand)
1174  return {};
1175  allSame = allSame && operand == adaptor.getShapes().front();
1176  }
1177  return BoolAttr::get(getContext(), allSame);
1178 }
1179 
1180 //===----------------------------------------------------------------------===//
1181 // IndexToSizeOp
1182 //===----------------------------------------------------------------------===//
1183 
1184 OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1185  // Constant values of both types, `shape.size` and `index`, are represented as
1186  // `IntegerAttr`s which makes constant folding simple.
1187  if (Attribute arg = adaptor.getArg())
1188  return arg;
1189  return {};
1190 }
1191 
1192 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1193  MLIRContext *context) {
1194  patterns.add<SizeToIndexToSizeCanonicalization>(context);
1195 }
1196 
1197 //===----------------------------------------------------------------------===//
1198 // FromExtentsOp
1199 //===----------------------------------------------------------------------===//
1200 
1201 OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1202  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
1203  return nullptr;
1204  SmallVector<int64_t, 6> extents;
1205  for (auto attr : adaptor.getExtents())
1206  extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1207  Builder builder(getContext());
1208  return builder.getIndexTensorAttr(extents);
1209 }
1210 
1211 //===----------------------------------------------------------------------===//
1212 // FunctionLibraryOp
1213 //===----------------------------------------------------------------------===//
1214 
1215 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1216  StringRef name) {
1217  result.attributes.push_back(builder.getNamedAttr(
1219 }
1220 
1221 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1222  auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1223  getMapping().get(op->getName().getIdentifier()));
1224  if (!attr)
1225  return nullptr;
1226  return lookupSymbol<FuncOp>(attr);
1227 }
1228 
1229 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1230  OperationState &result) {
1231  // Parse the op name.
1232  StringAttr nameAttr;
1233  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1234  result.attributes))
1235  return failure();
1236 
1237  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1238  return failure();
1239 
1240  auto *bodyRegion = result.addRegion();
1241  if (parser.parseRegion(*bodyRegion))
1242  return failure();
1243 
1244  if (parser.parseKeyword("mapping"))
1245  return failure();
1246 
1247  DictionaryAttr mappingAttr;
1248  if (parser.parseAttribute(mappingAttr,
1249  parser.getBuilder().getType<NoneType>(), "mapping",
1250  result.attributes))
1251  return failure();
1252  return success();
1253 }
1254 
1256  p << ' ';
1257  p.printSymbolName(getName());
1259  (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1260  p << ' ';
1261  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1262  /*printBlockTerminators=*/false);
1263  p << " mapping ";
1264  p.printAttributeWithoutType(getMappingAttr());
1265 }
1266 
1267 //===----------------------------------------------------------------------===//
1268 // FuncOp
1269 //===----------------------------------------------------------------------===//
1270 
1271 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1272  ArrayRef<NamedAttribute> attrs) {
1273  OpBuilder builder(location->getContext());
1274  OperationState state(location, getOperationName());
1275  FuncOp::build(builder, state, name, type, attrs);
1276  return cast<FuncOp>(Operation::create(state));
1277 }
1278 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1280  SmallVector<NamedAttribute, 8> attrRef(attrs);
1281  return create(location, name, type, llvm::ArrayRef(attrRef));
1282 }
1283 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1285  ArrayRef<DictionaryAttr> argAttrs) {
1286  FuncOp func = create(location, name, type, attrs);
1287  func.setAllArgAttrs(argAttrs);
1288  return func;
1289 }
1290 
1291 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
1292  FunctionType type, ArrayRef<NamedAttribute> attrs,
1293  ArrayRef<DictionaryAttr> argAttrs) {
1294  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
1295  builder.getStringAttr(name));
1296  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
1297  TypeAttr::get(type));
1298  state.attributes.append(attrs.begin(), attrs.end());
1299  state.addRegion();
1300 
1301  if (argAttrs.empty())
1302  return;
1303  assert(type.getNumInputs() == argAttrs.size());
1305  builder, state, argAttrs, /*resultAttrs=*/{},
1306  getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
1307 }
1308 
1309 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1310  auto buildFuncType =
1311  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
1313  std::string &) { return builder.getFunctionType(argTypes, results); };
1314 
1316  parser, result, /*allowVariadic=*/false,
1317  getFunctionTypeAttrName(result.name), buildFuncType,
1318  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1319 }
1320 
1321 void FuncOp::print(OpAsmPrinter &p) {
1323  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
1324  getArgAttrsAttrName(), getResAttrsAttrName());
1325 }
1326 
1327 //===----------------------------------------------------------------------===//
1328 // GetExtentOp
1329 //===----------------------------------------------------------------------===//
1330 
1331 std::optional<int64_t> GetExtentOp::getConstantDim() {
1332  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1333  return constSizeOp.getValue().getLimitedValue();
1334  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1335  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1336  return std::nullopt;
1337 }
1338 
1339 OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
1340  auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1341  if (!elements)
1342  return nullptr;
1343  std::optional<int64_t> dim = getConstantDim();
1344  if (!dim.has_value())
1345  return nullptr;
1346  if (dim.value() >= elements.getNumElements())
1347  return nullptr;
1348  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1349 }
1350 
1351 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1352  int64_t dim) {
1353  auto loc = result.location;
1354  auto dimAttr = builder.getIndexAttr(dim);
1355  if (llvm::isa<ShapeType>(shape.getType())) {
1356  Value dim = ConstSizeOp::create(builder, loc, dimAttr);
1357  build(builder, result, builder.getType<SizeType>(), shape, dim);
1358  } else {
1359  Value dim = arith::ConstantOp::create(builder, loc, builder.getIndexType(),
1360  dimAttr);
1361  build(builder, result, builder.getIndexType(), shape, dim);
1362  }
1363 }
1364 
1365 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1366  MLIRContext *context, std::optional<Location> location,
1367  GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1368  inferredReturnTypes.assign({IndexType::get(context)});
1369  return success();
1370 }
1371 
1372 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1373  TypeRange r) {
1374  // SizeType is compatible with IndexType.
1375  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1376 }
1377 
1378 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1379 
1380 //===----------------------------------------------------------------------===//
1381 // IsBroadcastableOp
1382 //===----------------------------------------------------------------------===//
1383 
1384 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1385  MLIRContext *context) {
1386  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1387 }
1388 
1389 OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1390  // Can always broadcast fewer than two shapes.
1391  if (adaptor.getShapes().size() < 2) {
1392  return BoolAttr::get(getContext(), true);
1393  }
1394 
1395  return nullptr;
1396 }
1397 
1398 //===----------------------------------------------------------------------===//
1399 // MeetOp
1400 //===----------------------------------------------------------------------===//
1401 
1402 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1403  MLIRContext *context, std::optional<Location> location,
1404  MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1405  if (adaptor.getOperands().empty())
1406  return failure();
1407 
1408  auto isShapeType = [](Type arg) {
1409  if (llvm::isa<ShapeType>(arg))
1410  return true;
1411  return isExtentTensorType(arg);
1412  };
1413 
1414  ValueRange::type_range types = adaptor.getOperands().getTypes();
1415  Type acc = types.front();
1416  for (auto t : drop_begin(types)) {
1417  Type l = acc, r = t;
1418  if (!llvm::isa<ShapeType, SizeType>(l))
1419  std::swap(l, r);
1420 
1421  // Handle sizes, propagate error type if present.
1422  if (llvm::isa<SizeType>(l)) {
1423  if (llvm::isa<SizeType, IndexType>(r))
1424  acc = l;
1425  else
1426  return emitOptionalError(location, "requires all sizes or shapes");
1427  } else if (llvm::isa<IndexType>(l)) {
1428  if (llvm::isa<IndexType>(r))
1429  acc = r;
1430  else
1431  return emitOptionalError(location, "requires all sizes or shapes");
1432  } else if (llvm::isa<ShapeType>(l)) {
1433  // Handle shapes, propagate error type if present.
1434  if (isShapeType(r))
1435  acc = l;
1436  else
1437  return emitOptionalError(location, "requires all sizes or shapes");
1438  } else if (isExtentTensorType(l)) {
1439  auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
1440  auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
1441  if (ShapedType::isDynamic(rank1))
1442  acc = l;
1443  else if (ShapedType::isDynamic(rank2))
1444  acc = r;
1445  else if (rank1 != rank2)
1446  return emitOptionalError(location, "unequal shape cardinality");
1447  else
1448  acc = l;
1449  }
1450  }
1451  inferredReturnTypes.assign({acc});
1452  return success();
1453 }
1454 
1455 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1456  if (l.size() != 1 || r.size() != 1)
1457  return false;
1458  if (l == r)
1459  return true;
1460 
1461  Type lhs = l.front();
1462  Type rhs = r.front();
1463 
1464  if (!llvm::isa<ShapeType, SizeType>(lhs))
1465  std::swap(lhs, rhs);
1466 
1467  if (llvm::isa<SizeType>(lhs))
1468  return llvm::isa<SizeType, IndexType>(rhs);
1469  if (llvm::isa<ShapeType>(lhs))
1470  return llvm::isa<ShapeType, TensorType>(rhs);
1471 
1472  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1473  return true;
1474  return false;
1475 }
1476 
1477 //===----------------------------------------------------------------------===//
1478 // RankOp
1479 //===----------------------------------------------------------------------===//
1480 
1481 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1482  auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1483  if (!shape)
1484  return {};
1485  int64_t rank = shape.getNumElements();
1486  Builder builder(getContext());
1487  return builder.getIndexAttr(rank);
1488 }
1489 
1490 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1491 /// Constant folding fails in cases where only the rank is constant, not the
1492 /// shape itself.
1493 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1494 ///
1495 /// Example:
1496 ///
1497 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1498 /// %rank = shape.rank %shape
1499 ///
1500 /// becomes
1501 ///
1502 /// %rank = shape.const_size 3
1503 
1504 namespace {
1505 struct RankShapeOfCanonicalizationPattern
1506  : public OpRewritePattern<shape::RankOp> {
1508 
1509  LogicalResult matchAndRewrite(shape::RankOp op,
1510  PatternRewriter &rewriter) const override {
1511  auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1512  if (!shapeOfOp)
1513  return failure();
1514  auto rankedTensorType =
1515  llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1516  if (!rankedTensorType)
1517  return failure();
1518  int64_t rank = rankedTensorType.getRank();
1519  if (llvm::isa<IndexType>(op.getType())) {
1520  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1521  rank);
1522  } else if (llvm::isa<shape::SizeType>(op.getType())) {
1523  rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1524  } else {
1525  return failure();
1526  }
1527  return success();
1528  }
1529 };
1530 } // namespace
1531 
1532 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1533  MLIRContext *context) {
1534  patterns.add<RankShapeOfCanonicalizationPattern>(context);
1535 }
1536 
1537 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1538  MLIRContext *context, std::optional<Location> location,
1539  RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1540  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1541  inferredReturnTypes.assign({SizeType::get(context)});
1542  else
1543  inferredReturnTypes.assign({IndexType::get(context)});
1544  return success();
1545 }
1546 
1547 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1548  // SizeType is compatible with IndexType.
1549  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1550 }
1551 
1552 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1553 
1554 //===----------------------------------------------------------------------===//
1555 // NumElementsOp
1556 //===----------------------------------------------------------------------===//
1557 
1558 OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1559 
1560  // Fold only when argument constant.
1561  Attribute shape = adaptor.getShape();
1562  if (!shape)
1563  return {};
1564 
1565  APInt product(64, 1);
1566  for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
1567  product *= value;
1568  Builder builder(getContext());
1569  return builder.getIndexAttr(product.getLimitedValue());
1570 }
1571 
1572 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1573  MLIRContext *context, std::optional<Location> location,
1574  NumElementsOp::Adaptor adaptor,
1575  SmallVectorImpl<Type> &inferredReturnTypes) {
1576  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1577  inferredReturnTypes.assign({SizeType::get(context)});
1578  else
1579  inferredReturnTypes.assign({IndexType::get(context)});
1580  return success();
1581 }
1582 
1583 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1584  TypeRange r) {
1585  // SizeType is compatible with IndexType.
1586  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1587 }
1588 
1589 LogicalResult shape::NumElementsOp::verify() {
1590  return verifySizeOrIndexOp(*this);
1591 }
1592 
1593 //===----------------------------------------------------------------------===//
1594 // MaxOp
1595 //===----------------------------------------------------------------------===//
1596 
1597 OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
1598  // If operands are equal, just propagate one.
1599  if (getLhs() == getRhs())
1600  return getLhs();
1601  return nullptr;
1602 }
1603 
1604 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1605  MLIRContext *context, std::optional<Location> location,
1606  MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1607  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1608  inferredReturnTypes.assign({adaptor.getLhs().getType()});
1609  else
1610  inferredReturnTypes.assign({SizeType::get(context)});
1611  return success();
1612 }
1613 
1614 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1615  if (l.size() != 1 || r.size() != 1)
1616  return false;
1617  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1618  return true;
1619  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1620  return true;
1621  return false;
1622 }
1623 
1624 //===----------------------------------------------------------------------===//
1625 // MinOp
1626 //===----------------------------------------------------------------------===//
1627 
1628 OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
1629  // If operands are equal, just propagate one.
1630  if (getLhs() == getRhs())
1631  return getLhs();
1632  return nullptr;
1633 }
1634 
1635 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1636  MLIRContext *context, std::optional<Location> location,
1637  MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1638  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1639  inferredReturnTypes.assign({adaptor.getLhs().getType()});
1640  else
1641  inferredReturnTypes.assign({SizeType::get(context)});
1642  return success();
1643 }
1644 
1645 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1646  if (l.size() != 1 || r.size() != 1)
1647  return false;
1648  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1649  return true;
1650  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1651  return true;
1652  return false;
1653 }
1654 
1655 //===----------------------------------------------------------------------===//
1656 // MulOp
1657 //===----------------------------------------------------------------------===//
1658 
1659 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1660  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1661  if (!lhs)
1662  return nullptr;
1663  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1664  if (!rhs)
1665  return nullptr;
1666  APInt folded = lhs.getValue() * rhs.getValue();
1667  Type indexTy = IndexType::get(getContext());
1668  return IntegerAttr::get(indexTy, folded);
1669 }
1670 
1671 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1672  MLIRContext *context, std::optional<Location> location,
1673  MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1674  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1675  llvm::isa<SizeType>(adaptor.getRhs().getType()))
1676  inferredReturnTypes.assign({SizeType::get(context)});
1677  else
1678  inferredReturnTypes.assign({IndexType::get(context)});
1679  return success();
1680 }
1681 
1682 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1683  // SizeType is compatible with IndexType.
1684  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1685 }
1686 
1687 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1688 
1689 //===----------------------------------------------------------------------===//
1690 // ShapeOfOp
1691 //===----------------------------------------------------------------------===//
1692 
1693 namespace {
1694 /// Replace shape_of(x) where x has a constant shape with a const_shape op.
1695 struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
1697 
1698  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1699  PatternRewriter &rewriter) const override {
1700  auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1701  if (!type || !type.hasStaticShape())
1702  return failure();
1703  Location loc = op.getLoc();
1704  Value constShape =
1705  ConstShapeOp::create(rewriter, loc,
1706  rewriter.getIndexTensorAttr(type.getShape()))
1707  .getResult();
1708  if (constShape.getType() != op.getResult().getType())
1709  constShape = tensor::CastOp::create(rewriter, loc,
1710  op.getResult().getType(), constShape);
1711  rewriter.replaceOp(op, constShape);
1712  return success();
1713  }
1714 };
1715 
1716 // Canonicalize
1717 //
1718 // %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1719 // %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
1720 //
1721 // to
1722 //
1723 // %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1724 // %1 = %shape
1725 //
1726 struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
1728 
1729  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1730  PatternRewriter &rewriter) const override {
1731  auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1732  if (!tensorReshapeOp)
1733  return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
1734  if (!isa<TensorType>(op.getType()))
1735  return rewriter.notifyMatchFailure(op, "result is not a tensor");
1736 
1737  // Operand 'shape' of 'tensor.reshape' may now be used as the result of
1738  // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
1739  // formed IR, it may not be identical (dynamically vs statically shaped),
1740  // in which case it needs to be cast first using 'tensor.cast'.
1741  // Additionally, it may not have identical element type (i32 vs index)
1742  // while it has identical shaped type (dynamic vs static), in which case it
1743  // needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
1744  // op result must be shape or extent tensor.
1745  Value shape = tensorReshapeOp.getShape();
1746 
1747  auto opTensorTy = cast<RankedTensorType>(op.getType());
1748  auto shapeTensorTy = cast<RankedTensorType>(shape.getType());
1749 
1750  if (opTensorTy != shapeTensorTy) {
1751  if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
1752  shape =
1753  tensor::CastOp::create(rewriter, op.getLoc(), opTensorTy, shape);
1754  else if (!isExtentTensorType(shapeTensorTy))
1755  shape = arith::IndexCastOp::create(rewriter, op.getLoc(), opTensorTy,
1756  shape);
1757  }
1758 
1759  rewriter.replaceOp(op, shape);
1760  return success();
1761  }
1762 };
1763 
1764 // Canonicalize
1765 // ```
1766 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1767 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1768 // ```
1769 // to
1770 // ```
1771 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1772 // ```
1773 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1775 
1776  LogicalResult matchAndRewrite(tensor::CastOp op,
1777  PatternRewriter &rewriter) const override {
1778  auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1779  if (!ty || ty.getRank() != 1)
1780  return failure();
1781 
1782  auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1783  if (!shapeOfOp)
1784  return failure();
1785 
1786  // Argument type must be ranked and must not conflict.
1787  auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1788  if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1789  return failure();
1790 
1791  rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1792  return success();
1793  }
1794 };
1795 } // namespace
1796 
1797 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1798  MLIRContext *context) {
1799  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1800  ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1801  context);
1802 }
1803 
1804 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1805  MLIRContext *context, std::optional<Location> location,
1806  ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1807  if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1808  inferredReturnTypes.assign({ShapeType::get(context)});
1809  else {
1810  auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1811  int64_t rank =
1812  shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1813  Type indexTy = IndexType::get(context);
1814  Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1815  inferredReturnTypes.assign({extentTensorTy});
1816  }
1817  return success();
1818 }
1819 
1820 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1821  if (l.size() != 1 || r.size() != 1)
1822  return false;
1823  if (l == r)
1824  return true;
1825 
1826  Type lhs = l.front();
1827  Type rhs = r.front();
1828 
1829  if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1830  !llvm::isa<ShapeType, ShapedType>(rhs))
1831  return false;
1832 
1833  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1834  // Shape type is compatible with all other valid return types.
1835  return true;
1836 
1837  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1838  return true;
1839  return false;
1840 }
1841 
1842 LogicalResult shape::ShapeOfOp::verify() {
1843  return verifyShapeOrExtentTensorOp(*this);
1844 }
1845 
1846 //===----------------------------------------------------------------------===//
1847 // SizeToIndexOp
1848 //===----------------------------------------------------------------------===//
1849 
1850 OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1851  // Constant values of both types, `shape.size` and `index`, are represented as
1852  // `IntegerAttr`s which makes constant folding simple.
1853  if (Attribute arg = adaptor.getArg())
1854  return arg;
1855  return OpFoldResult();
1856 }
1857 
1858 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1859  MLIRContext *context) {
1860  patterns.add<IndexToSizeToIndexCanonicalization>(context);
1861 }
1862 
1863 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1864  if (inputs.size() != 1 || outputs.size() != 1)
1865  return false;
1866  return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1867  llvm::isa<IndexType>(outputs[0]);
1868 }
1869 
1870 //===----------------------------------------------------------------------===//
1871 // YieldOp
1872 //===----------------------------------------------------------------------===//
1873 
1874 LogicalResult shape::YieldOp::verify() {
1875  auto *parentOp = (*this)->getParentOp();
1876  auto results = parentOp->getResults();
1877  auto operands = getOperands();
1878 
1879  if (parentOp->getNumResults() != getNumOperands())
1880  return emitOpError() << "number of operands does not match number of "
1881  "results of its parent";
1882  for (auto e : llvm::zip(results, operands))
1883  if (std::get<0>(e).getType() != std::get<1>(e).getType())
1884  return emitOpError() << "types mismatch between yield op and its parent";
1885 
1886  return success();
1887 }
1888 
1889 //===----------------------------------------------------------------------===//
1890 // SplitAtOp
1891 //===----------------------------------------------------------------------===//
1892 
1893 LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1894  SmallVectorImpl<OpFoldResult> &results) {
1895  if (!adaptor.getOperand() || !adaptor.getIndex())
1896  return failure();
1897  auto shapeVec = llvm::to_vector<6>(
1898  llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1899  auto shape = llvm::ArrayRef(shapeVec);
1900  auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1901  // Verify that the split point is in the correct range.
1902  // TODO: Constant fold to an "error".
1903  int64_t rank = shape.size();
1904  if (-rank > splitPoint || splitPoint > rank)
1905  return failure();
1906  if (splitPoint < 0)
1907  splitPoint += shape.size();
1908  Builder builder(adaptor.getOperand().getContext());
1909  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1910  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1911  return success();
1912 }
1913 
1914 //===----------------------------------------------------------------------===//
1915 // ToExtentTensorOp
1916 //===----------------------------------------------------------------------===//
1917 
1918 OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1919  if (!adaptor.getInput())
1920  return OpFoldResult();
1921  Builder builder(getContext());
1922  auto shape = llvm::to_vector<6>(
1923  llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1924  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1925  builder.getIndexType());
1926  return DenseIntElementsAttr::get(type, shape);
1927 }
1928 
1929 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1930  if (inputs.size() != 1 || outputs.size() != 1)
1931  return false;
1932  if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1933  if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1934  inputTensor.getRank() != 1)
1935  return false;
1936  } else if (!llvm::isa<ShapeType>(inputs[0])) {
1937  return false;
1938  }
1939 
1940  TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1941  return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1942 }
1943 
1944 //===----------------------------------------------------------------------===//
1945 // ReduceOp
1946 //===----------------------------------------------------------------------===//
1947 
1948 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1949  ValueRange initVals) {
1950  OpBuilder::InsertionGuard g(builder);
1951  result.addOperands(shape);
1952  result.addOperands(initVals);
1953 
1954  Region *bodyRegion = result.addRegion();
1955  Block *bodyBlock = builder.createBlock(
1956  bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);
1957 
1958  Type elementType;
1959  if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
1960  elementType = tensorType.getElementType();
1961  else
1962  elementType = SizeType::get(builder.getContext());
1963  bodyBlock->addArgument(elementType, shape.getLoc());
1964 
1965  for (Value initVal : initVals) {
1966  bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
1967  result.addTypes(initVal.getType());
1968  }
1969 }
1970 
1971 LogicalResult ReduceOp::verify() {
1972  // Verify block arg types.
1973  Block &block = getRegion().front();
1974 
1975  // The block takes index, extent, and aggregated values as arguments.
1976  auto blockArgsCount = getInitVals().size() + 2;
1977  if (block.getNumArguments() != blockArgsCount)
1978  return emitOpError() << "ReduceOp body is expected to have "
1979  << blockArgsCount << " arguments";
1980 
1981  // The first block argument is the index and must always be of type `index`.
1982  if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
1983  return emitOpError(
1984  "argument 0 of ReduceOp body is expected to be of IndexType");
1985 
1986  // The second block argument is the extent and must be of type `size` or
1987  // `index`, depending on whether the reduce operation is applied to a shape or
1988  // to an extent tensor.
1989  Type extentTy = block.getArgument(1).getType();
1990  if (llvm::isa<ShapeType>(getShape().getType())) {
1991  if (!llvm::isa<SizeType>(extentTy))
1992  return emitOpError("argument 1 of ReduceOp body is expected to be of "
1993  "SizeType if the ReduceOp operates on a ShapeType");
1994  } else {
1995  if (!llvm::isa<IndexType>(extentTy))
1996  return emitOpError(
1997  "argument 1 of ReduceOp body is expected to be of IndexType if the "
1998  "ReduceOp operates on an extent tensor");
1999  }
2000 
2001  for (const auto &type : llvm::enumerate(getInitVals()))
2002  if (block.getArgument(type.index() + 2).getType() != type.value().getType())
2003  return emitOpError() << "type mismatch between argument "
2004  << type.index() + 2
2005  << " of ReduceOp body and initial value "
2006  << type.index();
2007  return success();
2008 }
2009 
2010 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2011  // Parse operands.
2013  Type shapeOrExtentTensorType;
2014  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
2016  parser.parseColonType(shapeOrExtentTensorType) ||
2017  parser.parseOptionalArrowTypeList(result.types))
2018  return failure();
2019 
2020  // Resolve operands.
2021  auto initVals = llvm::ArrayRef(operands).drop_front();
2022  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
2023  result.operands) ||
2024  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2025  result.operands))
2026  return failure();
2027 
2028  // Parse the body.
2029  Region *body = result.addRegion();
2030  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
2031  return failure();
2032 
2033  // Parse attributes.
2034  if (parser.parseOptionalAttrDict(result.attributes))
2035  return failure();
2036 
2037  return success();
2038 }
2039 
2040 void ReduceOp::print(OpAsmPrinter &p) {
2041  p << '(' << getShape() << ", " << getInitVals()
2042  << ") : " << getShape().getType();
2043  p.printOptionalArrowTypeList(getResultTypes());
2044  p << ' ';
2045  p.printRegion(getRegion());
2046  p.printOptionalAttrDict((*this)->getAttrs());
2047 }
2048 
2049 #define GET_OP_CLASSES
2050 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2051 
2052 #define GET_TYPEDEF_CLASSES
2053 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static bool isErrorPropagationPossible(TypeRange operandTypes)
Definition: Shape.cpp:66
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
Definition: Shape.cpp:974
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
Definition: Shape.cpp:83
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
Definition: Shape.cpp:96
static LogicalResult verifySizeOrIndexOp(Operation *op)
Definition: Shape.cpp:71
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
Operation & front()
Definition: Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:89
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:188
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:50
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:89
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:702
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
ValueTypeRange< ValueRange > type_range
Definition: ValueRange.h:418
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition: Traits.cpp:24
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:59
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
bool isExtentTensorType(Type)
Definition: Shape.cpp:44
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
Definition: Shape.cpp:49
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition: Shape.cpp:40
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.